Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/vision/image.py
781 views
1
"`Image` provides support to convert, transform and show images"
2
from ..torch_core import *
3
from ..basic_data import *
4
from ..layers import MSELossFlat
5
from io import BytesIO
6
import PIL
7
8
# Public names exported by `from ... import *`; includes re-exported `PIL` and `TfmList`,
# which are not defined in this module.
__all__ = [
    'PIL', 'Image', 'ImageBBox', 'ImageSegment', 'ImagePoints', 'FlowField', 'RandTransform',
    'TfmAffine', 'TfmCoord', 'TfmCrop', 'TfmLighting', 'TfmPixel', 'Transform',
    'bb2hw', 'image2np', 'open_image', 'open_mask', 'tis2hw',
    'pil2tensor', 'scale_flow', 'show_image', 'CoordFunc', 'TfmList', 'open_mask_rle', 'rle_encode',
    'rle_decode', 'ResizeMethod', 'plot_flat', 'plot_multi', 'show_multi', 'show_all',
]

class ResizeMethod(IntEnum):
    # CROP/PAD resize keeping aspect ratio then crop/pad to target; SQUISH resizes straight
    # to the target shape; NO skips the resize step (see `Image.apply_tfms`).
    CROP = 1
    PAD = 2
    SQUISH = 3
    NO = 4

def pil2tensor(image:Union[NPImage,NPArray],dtype:np.dtype)->TensorImage:
    "Convert PIL style `image` array to torch style image tensor."
    arr = np.asarray(image)
    # Give 2-D (grayscale) input an explicit trailing channel axis.
    if arr.ndim == 2: arr = arr[:, :, None]
    # Channels-last (H,W,C) -> channels-first (C,H,W) in a single transpose view.
    arr = arr.transpose(2, 0, 1)
    return torch.from_numpy(arr.astype(dtype, copy=False))

def image2np(image:Tensor)->np.ndarray:
    "Convert from torch style `image` to numpy/matplotlib style."
    # Channels-first (C,H,W) tensor -> channels-last (H,W,C) array on the CPU.
    arr = image.cpu().permute(1,2,0).numpy()
    # Matplotlib wants single-channel images as plain 2-D arrays.
    if arr.shape[2] == 1: return arr[..., 0]
    return arr

def bb2hw(a:Collection[int])->np.ndarray:
    """Convert a `(top, left, bottom, right)` bounding box to `(left, top, width, height)`.

    The output is the layout matplotlib's `Rectangle(xy, width, height)` expects:
    the anchor point followed by the box extent.
    """
    return np.array([a[1],a[0],a[3]-a[1],a[2]-a[0]])

def tis2hw(size:Union[int,TensorImageSize]) -> Tuple[int,int]:
    """Convert `int` or `TensorImageSize` to (height,width) of an image.

    An `int` is used for both dimensions; otherwise the last two entries of `size` are taken.
    Raises `RuntimeError` for strings (including `str` subclasses such as `numpy.str_`),
    which would otherwise be silently sliced as a sequence.
    """
    if isinstance(size, str): raise RuntimeError("Expected size to be an int or a tuple, got a string.")
    return listify(size, 2) if isinstance(size, int) else listify(size[-2:],2)

def _draw_outline(o:Patch, lw:int):
    "Outline bounding box onto image `Patch`."
    # Path effects render in list order: a black stroke of width `lw`, then the artist itself.
    stroke = patheffects.Stroke(linewidth=lw, foreground='black')
    o.set_path_effects([stroke, patheffects.Normal()])

def _draw_rect(ax:plt.Axes, b:Collection[int], color:str='white', text=None, text_size=14):
    "Draw bounding box on `ax`, optionally labelled with `text` at its anchor corner."
    # `b[:2]` is the anchor point and `b[-2:]` the (width, height) given to `Rectangle`.
    rect = patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2)
    _draw_outline(ax.add_patch(rect), 4)
    if text is None: return
    label = ax.text(*b[:2], text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')
    _draw_outline(label, 1)

def _get_default_args(func:Callable):
50
return {k: v.default
51
for k, v in inspect.signature(func).parameters.items()
52
if v.default is not inspect.Parameter.empty}
53
54
@dataclass
class FlowField:
    "Wrap together some coords `flow` with a `size`."
    size:Tuple[int,int]   # spatial size the coordinates refer to
    flow:Tensor           # the coordinate tensor

# Signature of a coord transform: called with a `FlowField` plus extra args/kwargs.
# NOTE(review): `Image.coord` assigns the result back to `flow`, so in practice this returns
# a `FlowField`; the `LogitTensorImage` return annotation looks stale — confirm.
CoordFunc = Callable[[FlowField, ArgStar, KWArgs], LogitTensorImage]
61
62
class Image(ItemBase):
    "Support applying transforms to image data in `px`."
    def __init__(self, px:Tensor):
        # Logit, affine and flow transforms are queued on the object and only baked into
        # `_px` by `refresh` (which the `px` getter calls).
        self._px = px
        self._logit_px=None
        self._flow=None
        self._affine_mat=None
        self.sample_kwargs = {}

    def set_sample(self, **kwargs)->'Image':
        "Set parameters that control how we `grid_sample` the image after transforms are applied."
        self.sample_kwargs = kwargs
        return self

    def clone(self):
        "Mimic the behavior of torch.clone for `Image` objects."
        # Reading `px` flushes any transforms still pending on `self` before copying.
        return self.__class__(self.px.clone())

    @property
    def shape(self)->Tuple[int,int,int]: return self._px.shape
    @property
    def size(self)->Tuple[int,int]: return self.shape[-2:]
    @property
    def device(self)->torch.device: return self._px.device

    def __repr__(self): return f'{self.__class__.__name__} {tuple(self.shape)}'
    def _repr_png_(self): return self._repr_image_format('png')
    def _repr_jpeg_(self): return self._repr_image_format('jpeg')

    def _repr_image_format(self, format_str):
        "Encode the image as `format_str` bytes for rich (notebook) display."
        with BytesIO() as str_buffer:
            plt.imsave(str_buffer, image2np(self.px), format=format_str)
            return str_buffer.getvalue()

    def apply_tfms(self, tfms:TfmList, do_resolve:bool=True, xtra:Optional[Dict[Callable,dict]]=None,
                   size:Optional[Union[int,TensorImageSize]]=None, resize_method:ResizeMethod=None,
                   mult:int=None, padding_mode:str='reflection', mode:str='bilinear', remove_out:bool=True,
                   is_x:bool=True, x_frames:int=1, y_frames:int=1)->'Image':
        """Apply all `tfms` to the `Image`, if `do_resolve` picks value for random args.

        Returns a new (refreshed) object; `self` is returned unchanged when there is nothing to do.
        `xtra` maps a `Transform` to extra kwargs for it; `size`/`resize_method`/`mult` control the
        initial resize and the final crop/pad target.
        """
        if not (tfms or xtra or size): return self

        # For multi-frame inputs an int `size` becomes (size, size*frames): only the width is
        # scaled by the frame count (presumably frames are tiled horizontally).
        if size is not None and isinstance(size, int):
            num_frames = x_frames if is_x else y_frames
            if num_frames > 1:
                size = (size, size*num_frames)

        tfms = listify(tfms)
        xtra = ifnone(xtra, {})
        crop_or_pad = (ResizeMethod.CROP, ResizeMethod.PAD)
        default_rsz = ResizeMethod.SQUISH if (size is not None and is_listy(size)) else ResizeMethod.CROP
        resize_method = ifnone(resize_method, default_rsz)
        # NOTE(review): `_maybe_add_crop_pad` is not defined in this class; presumably it is
        # attached to `Image` elsewhere (e.g. by the transforms module).
        if resize_method in crop_or_pad and size is not None: tfms = self._maybe_add_crop_pad(tfms)
        tfms = sorted(tfms, key=lambda o: o.tfm.order)
        if do_resolve: _resolve_tfms(tfms)
        x = self.clone()
        x.set_sample(padding_mode=padding_mode, mode=mode, remove_out=remove_out)
        if size is not None:
            crop_target = _get_crop_target(size, mult=mult)
            if resize_method in crop_or_pad:
                target = _get_resize_target(x, crop_target, do_crop=(resize_method==ResizeMethod.CROP))
                x.resize(target)
            elif resize_method==ResizeMethod.SQUISH: x.resize((x.shape[0],) + crop_target)
        else: size = x.size
        for tfm in tfms:
            if tfm.tfm in xtra: x = tfm(x, **xtra[tfm.tfm])
            # Crop transforms receive the final target size so they land exactly on it.
            elif isinstance(tfm.tfm, TfmCrop):
                if resize_method in crop_or_pad:
                    x = tfm(x, size=_get_crop_target(size,mult=mult), padding_mode=padding_mode)
                else: x = tfm(x)
            else: x = tfm(x)
        return x.refresh()

    def refresh(self)->'Image':
        "Apply any logit, flow, or affine transforms that have been sent to the `Image`, and return it."
        if self._logit_px is not None:
            self._px = self._logit_px.sigmoid_()
            self._logit_px = None
        if self._affine_mat is not None or self._flow is not None:
            # `self.flow` folds any pending affine matrix into the flow field first.
            self._px = _grid_sample(self._px, self.flow, **self.sample_kwargs)
            self.sample_kwargs = {}
            self._flow = None
        return self

    def save(self, fn:PathOrStr):
        "Save the image to `fn`."
        # NOTE(review): assumes pixel values are in [0,1]; values outside wrap around in uint8.
        x = image2np(self.data*255).astype(np.uint8)
        PIL.Image.fromarray(x).save(fn)

    @property
    def px(self)->TensorImage:
        "Get the tensor pixel buffer."
        self.refresh()
        return self._px
    @px.setter
    def px(self,v:TensorImage)->None:
        "Set the pixel buffer to `v`."
        self._px=v

    @property
    def flow(self)->FlowField:
        "Access the flow-field grid after applying queued affine transforms."
        if self._flow is None:
            self._flow = _affine_grid(self.shape)
        if self._affine_mat is not None:
            self._flow = _affine_mult(self._flow,self._affine_mat)
            self._affine_mat = None
        return self._flow

    @flow.setter
    def flow(self,v:FlowField): self._flow=v

    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'Image':
        "Equivalent to `image = sigmoid(func(logit(image)))`."
        self.logit_px = func(self.logit_px, *args, **kwargs)
        return self

    def pixel(self, func:PixelFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.px = func(image.px)`."
        self.px = func(self.px, *args, **kwargs)
        return self

    def coord(self, func:CoordFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.flow = func(image.flow, image.size)`."
        self.flow = func(self.flow, *args, **kwargs)
        return self

    def affine(self, func:AffineFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.affine_mat = image.affine_mat @ func()`."
        m = tensor(func(*args, **kwargs)).to(self.device)
        self.affine_mat = self.affine_mat @ m
        return self

    def resize(self, size:Union[int,TensorImageSize])->'Image':
        "Resize the image to `size`, size can be a single int."
        # A resize must be queued before any flow (coord) transform.
        assert self._flow is None
        if isinstance(size, int): size=(self.shape[0], size, size)
        if tuple(size)==tuple(self.shape): return self
        # The actual resampling happens in `refresh`, through an identity grid of the new size.
        self.flow = _affine_grid(size)
        return self

    @property
    def affine_mat(self)->AffineMatrix:
        "Get the affine matrix that will be applied by `refresh`."
        if self._affine_mat is None:
            self._affine_mat = torch.eye(3).to(self.device)
        return self._affine_mat
    @affine_mat.setter
    def affine_mat(self,v)->None: self._affine_mat=v

    @property
    def logit_px(self)->LogitTensorImage:
        "Get logit(image.px)."
        if self._logit_px is None: self._logit_px = logit_(self.px)
        return self._logit_px
    @logit_px.setter
    def logit_px(self,v:LogitTensorImage)->None: self._logit_px=v

    @property
    def data(self)->TensorImage:
        "Return this image's pixels as a tensor."
        return self.px

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
              cmap:str=None, y:Any=None, **kwargs):
        "Show image on `ax` with `title`, using `cmap` if single-channel, overlaid with optional `y`"
        cmap = ifnone(cmap, defaults.cmap)
        ax = show_image(self, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize)
        if y is not None: y.show(ax=ax, **kwargs)
        if title is not None: ax.set_title(title)

class ImageSegment(Image):
    "Support applying transforms to segmentation masks data in `px`."
    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'Image':
        "Lighting transforms are a no-op on masks."
        return self

    def refresh(self):
        "Force nearest-neighbour resampling (mask values are never interpolated), then refresh."
        self.sample_kwargs['mode'] = 'nearest'
        return super().refresh()

    @property
    def data(self)->TensorImage:
        "Return this image pixels as a `LongTensor`."
        pixels = self.px
        return pixels.long()

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
        cmap:str='tab20', alpha:float=0.5, **kwargs):
        "Show the `ImageSegment` on `ax`, as a semi-transparent nearest-interpolated overlay."
        shown_ax = show_image(self, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize,
                              interpolation='nearest', alpha=alpha, vmin=0, **kwargs)
        if not title: return
        shown_ax.set_title(title)

    def reconstruct(self, t:Tensor):
        "Build a new `ImageSegment` around tensor `t`."
        return ImageSegment(t)

class ImagePoints(Image):
    "Support applying transforms to a `flow` of points."
    def __init__(self, flow:FlowField, scale:bool=True, y_first:bool=True):
        # Points live only in a flow field (there is no pixel buffer); `Image.__init__` is not called.
        if scale: flow = scale_flow(flow)
        if y_first: flow.flow = flow.flow.flip(1)
        self._flow = flow
        self._affine_mat = None
        self.flow_func = []       # queued coord transforms, applied lazily by `flow`
        self.sample_kwargs = {}
        self.transformed = False  # set once a queued transform has been applied to the points
        self.loss_func = MSELossFlat()

    def clone(self):
        "Mimic the behavior of torch.clone for `ImagePoints` objects."
        copied = FlowField(self.size, self.flow.flow.clone())
        return self.__class__(copied, scale=False, y_first=False)

    @property
    def shape(self)->Tuple[int,int,int]: return (1, *self._flow.size)
    @property
    def size(self)->Tuple[int,int]: return self._flow.size
    @size.setter
    def size(self, sz:int): self._flow.size=sz
    @property
    def device(self)->torch.device: return self._flow.flow.device

    def __repr__(self): return f'{self.__class__.__name__} {tuple(self.size)}'
    def _repr_image_format(self, format_str): return None

    @property
    def flow(self)->FlowField:
        "Access the flow-field grid after applying queued affine and coord transforms."
        if self._affine_mat is not None:
            # Points move the opposite way to pixels, hence the inverse affine.
            self._flow = _affine_inv_mult(self._flow, self._affine_mat)
            self._affine_mat = None
            self.transformed = True
        if self.flow_func:
            for queued in reversed(self.flow_func): self._flow = queued(self._flow)
            self.transformed = True
            self.flow_func = []
        return self._flow

    @flow.setter
    def flow(self,v:FlowField): self._flow=v

    def coord(self, func:CoordFunc, *args, **kwargs)->'ImagePoints':
        "Put `func` with `args` and `kwargs` in `self.flow_func` for later."
        if 'invert' not in kwargs: warn(f"{func.__name__} isn't implemented for {self.__class__}.")
        else: kwargs['invert'] = True
        self.flow_func.append(partial(func, *args, **kwargs))
        return self

    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'ImagePoints': return self

    def pixel(self, func:PixelFunc, *args, **kwargs)->'ImagePoints':
        "Equivalent to `self = func_flow(self)`."
        result = func(self, *args, **kwargs)
        result.transformed=True
        return result

    def refresh(self) -> 'ImagePoints':
        "Nothing to flush: transforms are applied lazily by the `flow` property."
        return self

    def resize(self, size:Union[int,TensorImageSize]) -> 'ImagePoints':
        "Resize the image to `size`, size can be a single int."
        self._flow.size = (size, size) if isinstance(size, int) else size[1:]
        return self

    @property
    def data(self)->Tensor:
        "Return the points associated to this object."
        flow = self.flow #This updates flow before we test if some transforms happened
        if self.transformed:
            if self.sample_kwargs.get('remove_out', True): flow = _remove_points_out(flow)
            self.transformed=False
        return flow.flow.flip(1)

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True, **kwargs):
        "Show the `ImagePoints` on `ax`."
        if ax is None: _,ax = plt.subplots(figsize=figsize)
        # Map the [-1,1] coords back to image-size coords, then flip the columns for `scatter`.
        pixel_pts = scale_flow(FlowField(self.size, self.data), to_unit=False).flow.flip(1)
        params = {'s': 10, 'marker': '.', 'c': 'r', **kwargs}
        ax.scatter(pixel_pts[:, 0], pixel_pts[:, 1], **params)
        if hide_axis: ax.axis('off')
        if title: ax.set_title(title)

class ImageBBox(ImagePoints):
    "Support applying transforms to a `flow` of bounding boxes."
    def __init__(self, flow:FlowField, scale:bool=True, y_first:bool=True, labels:Collection=None,
                 classes:dict=None, pad_idx:int=0):
        super().__init__(flow, scale, y_first)
        self.pad_idx = pad_idx
        # Raw label ids are wrapped into `Category` objects using the `classes` lookup.
        if labels is not None and len(labels)>0 and not isinstance(labels[0],Category):
            labels = array([Category(l,classes[l]) for l in labels])
        self.labels = labels

    def clone(self) -> 'ImageBBox':
        "Mimic the behavior of torch.clone for `ImageBBox` objects."
        flow = FlowField(self.size, self.flow.flow.clone())
        return self.__class__(flow, scale=False, y_first=False, labels=self.labels, pad_idx=self.pad_idx)

    @classmethod
    def create(cls, h:int, w:int, bboxes:Collection[Collection[int]], labels:Collection=None, classes:dict=None,
               pad_idx:int=0, scale:bool=True)->'ImageBBox':
        "Create an ImageBBox object from `bboxes` given as (top, left, bottom, right) rows."
        # `np.object` was removed in NumPy 1.24; the builtin `object` is the same dtype.
        if isinstance(bboxes, np.ndarray) and bboxes.dtype == object: bboxes = np.array([bb for bb in bboxes])
        bboxes = tensor(bboxes).float()
        # Expand every box into its 4 corner points so transforms can move each corner.
        tr_corners = torch.cat([bboxes[:,0][:,None], bboxes[:,3][:,None]], 1)
        bl_corners = bboxes[:,1:3].flip(1)
        bboxes = torch.cat([bboxes[:,:2], tr_corners, bl_corners, bboxes[:,2:]], 1)
        flow = FlowField((h,w), bboxes.view(-1,2))
        return cls(flow, labels=labels, classes=classes, pad_idx=pad_idx, y_first=True, scale=scale)

    def _compute_boxes(self) -> Tuple[LongTensor, LongTensor]:
        """Return the enclosing box of each corner group (clamped to [-1,1]) and matching labels.

        Boxes with non-positive height or width are dropped. With no boxes at all, a 1-D padding
        sentinel `tensor([pad_idx]*4)` and `tensor([pad_idx])` are returned instead.
        """
        bboxes = self.flow.flow.flip(1).view(-1, 4, 2).contiguous().clamp(min=-1, max=1)
        mins, maxes = bboxes.min(dim=1)[0], bboxes.max(dim=1)[0]
        bboxes = torch.cat([mins, maxes], 1)
        mask = (bboxes[:,2]-bboxes[:,0] > 0) * (bboxes[:,3]-bboxes[:,1] > 0)
        if len(mask) == 0: return tensor([self.pad_idx] * 4), tensor([self.pad_idx])
        res = bboxes[mask]
        if self.labels is None: return res,None
        return res, self.labels[to_np(mask).astype(bool)]

    @property
    def data(self)->Union[FloatTensor, Tuple[FloatTensor,LongTensor]]:
        "Return the boxes, paired with an array of label data when labels are present."
        bboxes,lbls = self._compute_boxes()
        lbls = np.array([o.data for o in lbls]) if lbls is not None else None
        return bboxes if lbls is None else (bboxes, lbls)

    def show(self, y:Image=None, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
             color:str='white', **kwargs):
        "Show the `ImageBBox` on `ax`."
        if ax is None: _,ax = plt.subplots(figsize=figsize)
        bboxes, lbls = self._compute_boxes()
        h,w = self.flow.size
        # The 1-D padding sentinel means there are no boxes to draw.
        if bboxes.dim() == 2:
            # Map [-1,1] coords to image coords out-of-place (the previous in-place version
            # ended in a discarded `.long()`, so coordinates were floats anyway).
            bboxes = (bboxes + 1) * torch.tensor([h/2, w/2, h/2, w/2])
            for i, bbox in enumerate(bboxes):
                text = str(lbls[i]) if lbls is not None else None
                _draw_rect(ax, bb2hw(bbox), text=text, color=color)
        # Honour `hide_axis`/`title` like `ImagePoints.show` does.
        if hide_axis: ax.axis('off')
        if title: ax.set_title(title)

def open_image(fn:PathOrStr, div:bool=True, convert_mode:str='RGB', cls:type=Image,
        after_open:Callable=None)->Image:
    """Return `Image` object created from image in file `fn`.

    The file is converted to `convert_mode`, passed through `after_open` if given, turned into a
    float32 channels-first tensor, divided by 255 when `div`, and wrapped in `cls`.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning) # EXIF warning from TiffPlugin
        # `convert` loads the pixels into a new in-memory image, so the source file handle
        # can be closed right away instead of leaking until garbage collection.
        with PIL.Image.open(fn) as img:
            x = img.convert(convert_mode)
    if after_open: x = after_open(x)
    x = pil2tensor(x,np.float32)
    if div: x.div_(255)
    return cls(x)

def open_mask(fn:PathOrStr, div=False, convert_mode='L', after_open:Callable=None)->ImageSegment:
    "Return `ImageSegment` object created from mask in file `fn`. If `div`, divides pixel values by 255."
    # A mask is opened exactly like an image, just wrapped in `ImageSegment`.
    options = dict(div=div, convert_mode=convert_mode, cls=ImageSegment, after_open=after_open)
    return open_image(fn, **options)

def open_mask_rle(mask_rle:str, shape:Tuple[int, int])->ImageSegment:
    "Return `ImageSegment` object created from run-length encoded string `mask_rle` with size in `shape`."
    decoded = rle_decode(str(mask_rle), shape).astype(np.uint8)
    # NOTE(review): the view swaps the `shape` dims (rows=shape[1], cols=shape[0]), i.e. the
    # decoded buffer is read transposed — presumably for a column-major RLE convention; confirm.
    mask = FloatTensor(decoded).view(shape[1], shape[0], -1)
    return ImageSegment(mask.permute(2,0,1))

def rle_encode(img:NPArrayMask)->str:
    "Return run-length encoding string from `img` as space-separated 1-based `start length` pairs."
    # Zero padding on both ends guarantees every run has a start and an end transition.
    padded = np.concatenate([[0], img.flatten() , [0]])
    changes = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn each (start, end) pair into (start, length); assumes a 0/1 mask.
    changes[1::2] -= changes[::2]
    return ' '.join(map(str, changes))

def rle_decode(mask_rle:str, shape:Tuple[int,int])->NPArrayMask:
    "Return an image array from run-length encoded string `mask_rle` with `shape`."
    tokens = mask_rle.split()
    # Tokens alternate `start length`; encoded starts are 1-based.
    starts = np.asarray(tokens[0::2], dtype=int) - 1
    lengths = np.asarray(tokens[1::2], dtype=int)
    stops = starts + lengths
    flat = np.zeros(shape[0]*shape[1], dtype=np.uint)
    for lo, hi in zip(starts, stops): flat[lo:hi] = 1
    return flat.reshape(shape)

def show_image(img:Image, ax:plt.Axes=None, figsize:tuple=(3,3), hide_axis:bool=True, cmap:str='binary',
               alpha:float=None, **kwargs)->plt.Axes:
    "Display `Image` in notebook; a new figure of `figsize` is created when `ax` is None. Returns the axes."
    if ax is None: ax = plt.subplots(figsize=figsize)[1]
    ax.imshow(image2np(img.data), cmap=cmap, alpha=alpha, **kwargs)
    if hide_axis: ax.axis('off')
    return ax

def scale_flow(flow, to_unit=True):
    "Scale the coords in `flow` to -1/1 or the image size depending on `to_unit`; mutates and returns `flow`."
    half_size = tensor([flow.size[0]/2,flow.size[1]/2])[None]
    flow.flow = flow.flow/half_size - 1 if to_unit else (flow.flow + 1)*half_size
    return flow

def _remove_points_out(flow:FlowField):
    "Keep only the points of `flow` whose first two coords both lie in [-1, 1]; mutates and returns `flow`."
    coords = flow.flow
    inside = (coords[:,0] >= -1) & (coords[:,0] <= 1) & (coords[:,1] >= -1) & (coords[:,1] <= 1)
    flow.flow = coords[inside]
    return flow

class Transform():
    "Utility class for adding probability and wrapping support to transform `func`."
    _wrap=None  # name of the `Image` method that applies `func` (None: call `func` directly)
    order=0     # sort key used by `Image.apply_tfms`
    def __init__(self, func:Callable, order:Optional[int]=None):
        "Create a transform for `func` and assign it an priority `order`, attach to `Image` class."
        if order is not None: self.order = order
        self.func = func
        # Transform functions are named with a leading `_`; strip it for the public name.
        public_name = func.__name__[1:]
        func.__name__ = public_name
        functools.update_wrapper(self, func)
        func.__annotations__['return'] = Image
        # Parameter annotations double as random-value generators (see `RandTransform.resolve`).
        self.params = copy(func.__annotations__)
        self.def_args = _get_default_args(func)
        # Also expose the transform as a method on `Image`.
        setattr(Image, public_name,
                lambda x, *args, **kwargs: self.calc(x, *args, **kwargs))

    def __call__(self, *args:Any, p:float=1., is_random:bool=True, use_on_y:bool=True, **kwargs:Any)->Image:
        "Calc now if `args` passed; else create a transform called prob `p` if `random`."
        if not args:
            return RandTransform(self, kwargs=kwargs, is_random=is_random, use_on_y=use_on_y, p=p)
        return self.calc(*args, **kwargs)

    def calc(self, x:Image, *args:Any, **kwargs:Any)->Image:
        "Apply to image `x`, wrapping it if necessary."
        if not self._wrap: return self.func(x, *args, **kwargs)
        return getattr(x, self._wrap)(self.func, *args, **kwargs)

    @property
    def name(self)->str: return self.__class__.__name__

    def __repr__(self)->str: return f'{self.name} ({self.func.__name__})'

@dataclass
class RandTransform():
    "Wrap `Transform` to add randomized execution."
    tfm:Transform
    kwargs:dict
    p:float=1.0
    resolved:dict = field(default_factory=dict)
    do_run:bool = True
    is_random:bool = True
    use_on_y:bool = True
    def __post_init__(self): functools.update_wrapper(self, self.tfm)

    def resolve(self)->None:
        "Bind any random variables in the transform."
        if not self.is_random:
            # Deterministic: defaults, overridden by the explicitly passed kwargs.
            self.resolved = {**self.tfm.def_args, **self.kwargs}
            return

        params = self.tfm.params
        self.resolved = resolved = {}
        # Passed kwargs: an annotated param's annotation is a random function fed the value's items.
        for name, value in self.kwargs.items():
            resolved[name] = params[name](*listify(value)) if name in params else value
        # Then fill remaining params from the function's defaults...
        for name, default in self.tfm.def_args.items():
            resolved.setdefault(name, default)
        # ...and anything still missing must be drawable with no arguments.
        for name, draw in params.items():
            if name != 'return' and name not in resolved: resolved[name] = draw()

        self.do_run = rand_bool(self.p)

    @property
    def order(self)->int: return self.tfm.order

    def __call__(self, x:Image, *args, **kwargs)->Image:
        "Randomly execute our tfm on `x`."
        if not self.do_run: return x
        return self.tfm(x, *args, **{**self.resolved, **kwargs})

def _resolve_tfms(tfms:TfmList):
    "Resolve every tfm in `tfms` (a single tfm is accepted too), drawing their random parameters."
    for rand_tfm in listify(tfms):
        rand_tfm.resolve()

def _grid_sample(x:TensorImage, coords:FlowField, mode:str='bilinear', padding_mode:str='reflection', remove_out:bool=True)->TensorImage:
    """Resample pixels in `coords` from `x` by `mode`, with `padding_mode` in ('reflection','border','zeros').

    `remove_out` is accepted (it is part of `Image.sample_kwargs`) but unused here.
    """
    coords = coords.flow.permute(0, 3, 1, 2).contiguous().permute(0, 2, 3, 1) # optimize layout for grid_sample
    if mode=='bilinear': # hack to get smoother downwards resampling
        mn,mx = coords.min(),coords.max()
        span = (mx-mn).item()
        # max amount we're affine zooming by (>1 means zooming in); a degenerate grid where all
        # coords coincide (e.g. a 1x1 target) has unbounded zoom, so the pre-shrink is skipped.
        z = 1/span*2 if span > 0 else float('inf')
        # amount we're resizing by, with 100% extra margin
        d = min(x.shape[1]/coords.shape[1], x.shape[2]/coords.shape[2])/2
        # If we're resizing up by >200%, and we're zooming less than that, interpolate first
        if d>1 and d>z: x = F.interpolate(x[None], scale_factor=1/d, mode='area')[0]
    return F.grid_sample(x[None], coords, mode=mode, padding_mode=padding_mode)[0]

def _affine_grid(size:TensorImageSize)->FlowField:
    """Return an identity sampling grid over [-1,1] for an image of `size` (channels, height, width).

    The grid has shape (1, H, W, 2); channel 0 is the coordinate varying along W and
    channel 1 the one varying along H.
    """
    size = ((1,)+size)
    N, C, H, W = size
    grid = FloatTensor(N, H, W, 2)
    # A length-1 axis sits at -1. Float fallbacks plus broadcasting replace the deprecated
    # `torch.ger` outer products, which on older PyTorch reject the mixed float/int64 inputs.
    xs = torch.linspace(-1, 1, W) if W > 1 else torch.tensor([-1.])
    ys = torch.linspace(-1, 1, H) if H > 1 else torch.tensor([-1.])
    grid[:, :, :, 0] = xs.view(1, 1, -1)
    grid[:, :, :, 1] = ys.view(1, -1, 1)
    return FlowField(size[2:], grid)

def _affine_mult(c:FlowField,m:AffineMatrix)->FlowField:
    """Multiply `c` by `m` - can adjust for rectangular shaped `c`.

    Applies `p @ A.T + b` to every coordinate pair, where `A`/`b` are the 2x2 linear part and the
    translation of `m`. The off-diagonal terms are rescaled by the aspect ratio of `c.size`.
    `c` is updated in place and returned; `m` is left untouched.
    """
    if m is None: return c
    size = c.flow.size()
    h,w = c.size
    # Work on a copy so the aspect-ratio correction does not leak into the caller's matrix.
    m = m.clone()
    m[0,1] *= h/w
    m[1,0] *= w/h
    c.flow = c.flow.view(-1,2)
    c.flow = torch.addmm(m[:2,2], c.flow, m[:2,:2].t()).view(size)
    return c

def _affine_inv_mult(c, m):
566
"Applies the inverse affine transform described in `m` to `c`."
567
size = c.flow.size()
568
h,w = c.size
569
m[0,1] *= h/w
570
m[1,0] *= w/h
571
c.flow = c.flow.view(-1,2)
572
a = torch.inverse(m[:2,:2].t())
573
c.flow = torch.mm(c.flow - m[:2,2], a).view(size)
574
return c
575
576
class TfmAffine(Transform):
    "Decorator for affine tfm funcs."
    order = 5
    _wrap = 'affine'    # applied through `Image.affine`

class TfmPixel(Transform):
    "Decorator for pixel tfm funcs."
    order = 10
    _wrap = 'pixel'     # applied through `Image.pixel`

class TfmCoord(Transform):
    "Decorator for coord tfm funcs."
    order = 4
    _wrap = 'coord'     # applied through `Image.coord`

class TfmCrop(TfmPixel):
    "Decorator for crop tfm funcs."
    # Keeps the inherited 'pixel' wrapper; the high order sorts crops after the other kinds.
    order = 99

class TfmLighting(Transform):
    "Decorator for lighting tfm funcs."
    order = 8
    _wrap = 'lighting'  # applied through `Image.lighting`

def _round_multiple(x:int, mult:int=None)->int:
593
"Calc `x` to nearest multiple of `mult`."
594
return (int(x/mult+0.5)*mult) if mult is not None else x
595
596
def _get_crop_target(target_px:Union[int,TensorImageSize], mult:int=None)->Tuple[int,int]:
    "Calc crop shape (rows, cols) of `target_px`, each rounded to the nearest multiple of `mult`."
    rows, cols = tis2hw(target_px)
    return tuple(_round_multiple(v, mult) for v in (rows, cols))

def _get_resize_target(img, crop_target, do_crop=False)->TensorImageSize:
    """Calc size of `img` to fit in `crop_target` - adjust based on `do_crop`.

    Aspect ratio is preserved. With `do_crop` the smaller ratio is used, so both sides end up at
    least the target (ready to crop); otherwise the larger one, so both fit inside (ready to pad).
    """
    if crop_target is None: return None
    ch, rows, cols = img.shape
    target_r, target_c = crop_target
    pick = min if do_crop else max
    ratio = pick(rows/target_r, cols/target_c)
    #Sometimes those are numpy numbers and round doesn't return an int.
    new_rows = int(round(rows/ratio))
    new_cols = int(round(cols/ratio))
    return ch, new_rows, new_cols

def plot_flat(r, c, figsize):
    "Shortcut for `enumerate(subplots.flatten())`: yields `(index, ax)` over an `r` x `c` grid in row order."
    # `squeeze=False` always returns a 2-D array of Axes; otherwise a 1x1 grid comes back as a
    # bare `Axes`, which has no `.flatten()`.
    return enumerate(plt.subplots(r, c, figsize=figsize, squeeze=False)[1].flatten())

def plot_multi(func:Callable[[int,int,plt.Axes],None], r:int=1, c:int=1, figsize:Tuple=(12,6)):
    "Call `func(i, j, ax)` for every combination of `r,c` on a subplot"
    # `squeeze=False` keeps `axes` a 2-D (r, c) array; without it a single row/column (including
    # the default 1x1 grid) comes back 1-D or as a bare Axes, and `axes[i,j]` fails.
    axes = plt.subplots(r, c, figsize=figsize, squeeze=False)[1]
    for i in range(r):
        for j in range(c): func(i,j,axes[i,j])

def show_multi(func:Callable[[int,int],Image], r:int=1, c:int=1, figsize:Tuple=(9,9)):
    "Call `func(i,j).show(ax)` for every combination of `r,c`"
    def _show_cell(i, j, ax): return func(i, j).show(ax)
    plot_multi(_show_cell, r, c, figsize=figsize)

def show_all(imgs:Collection[Image], r:int=1, c:Optional[int]=None, figsize=(12,6)):
    "Show all `imgs` using `r` rows (and `c` columns, by default just enough to fit every image)"
    imgs = listify(imgs)
    # Ceiling division: `len//r` silently dropped the images of a final partial row.
    if c is None: c = -(-len(imgs) // r)
    for i,ax in plot_flat(r,c,figsize):
        # The grid can have more cells than images; blank the spare ones instead of IndexError.
        if i < len(imgs): imgs[i].show(ax)
        else: ax.axis('off')
