Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/vision/models/unet.py
781 views
1
from ...torch_core import *
2
from ...layers import *
3
from ...callbacks.hooks import *
4
5
# Names exported by a star-import of this module.
__all__ = [
    "DynamicUnet",
    "UnetBlock",
]
6
7
def _get_sfs_idxs(sizes:Sizes) -> List[int]:
8
"Get the indexes of the layers where the size of the activation changes."
9
feature_szs = [size[-1] for size in sizes]
10
sfs_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
11
if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs
12
return sfs_idxs
13
14
class UnetBlock(Module):
    """A quasi-UNet decoder block, upsampling with `PixelShuffle_ICNR`.

    The incoming feature map is upsampled to `up_in_c//2` channels, concatenated
    with the batch-normalised activation stored by `hook`, passed through a
    ReLU, then through two conv layers. `final_div=False` halves the channel
    count of the conv layers; `self_attention` is forwarded to the second one.
    """
    def __init__(self, up_in_c:int, x_in_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
                 self_attention:bool=False, **kwargs):
        self.hook = hook
        half_up = up_in_c // 2
        self.shuf = PixelShuffle_ICNR(up_in_c, half_up, blur=blur, leaky=leaky, **kwargs)
        self.bn = batchnorm_2d(x_in_c)
        # Channels after concatenating the upsampled map with the skip activation.
        cat_ch = half_up + x_in_c
        out_ch = cat_ch if final_div else cat_ch // 2
        self.conv1 = conv_layer(cat_ch, out_ch, leaky=leaky, **kwargs)
        self.conv2 = conv_layer(out_ch, out_ch, leaky=leaky, self_attention=self_attention, **kwargs)
        self.relu = relu(leaky=leaky)

    def forward(self, up_in:Tensor) -> Tensor:
        skip = self.hook.stored
        upsampled = self.shuf(up_in)
        target_hw = skip.shape[-2:]
        if upsampled.shape[-2:] != target_hw:
            # Spatial sizes disagree: resize (nearest) so the concat below is valid.
            upsampled = F.interpolate(upsampled, target_hw, mode='nearest')
        merged = self.relu(torch.cat([upsampled, self.bn(skip)], dim=1))
        hidden = self.conv1(merged)
        return self.conv2(hidden)
35
36
37
class DynamicUnet(SequentialEx):
    """Create a U-Net from a given architecture.

    The encoder is probed with a dummy forward pass at `img_size` to find the
    layers where the activation width changes; forward hooks on those layers
    provide the skip connections, each consumed by one `UnetBlock`.

    Args:
        encoder: indexable module used as the down path (layers are accessed as `encoder[i]`).
        n_classes: output channels of the final 1x1 conv.
        img_size: input size used for the dummy passes that size the decoder.
        blur: enable blur in the `UnetBlock`s; the last block only gets it if `blur_final`.
        blur_final: see `blur`.
        self_attention: enable self-attention in the third-from-last `UnetBlock`.
        y_range: if given, append `SigmoidRange(*y_range)` to the output.
        last_cross: append `MergeLayer(dense=True)` followed by a `res_block`.
        bottle: forwarded to that `res_block`.
        **kwargs: forwarded to the conv layers, `UnetBlock`s, `PixelShuffle_ICNR` and `res_block`.
    """
    def __init__(self, encoder:nn.Module, n_classes:int, img_size:Tuple[int,int]=(256,256), blur:bool=False, blur_final=True, self_attention:bool=False,
                 y_range:Optional[Tuple[float,float]]=None,
                 last_cross:bool=True, bottle:bool=False, **kwargs):
        imsize = img_size
        # Output shape of every encoder layer at `imsize`.
        sfs_szs = model_sizes(encoder, size=imsize)
        # Size-change layers, latest first: the decoder consumes skips in reverse encoder order.
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        # detach=False keeps the stored skip activations in the autograd graph so that
        # gradients through each UnetBlock's skip path reach the encoder. fastai's
        # `hook_outputs` defaults to detach=True, which cut the encoder off from them.
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        # Dummy forward pass: fills the hooks and yields the bottleneck activation,
        # which is then pushed through each new decoder layer to size the next one.
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        # New layers are put in eval mode before the dummy pass (presumably so it
        # does not update normalisation statistics).
        middle_conv = nn.Sequential(conv_layer(ni, ni*2, **kwargs),
                                    conv_layer(ni*2, ni, **kwargs)).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        for i,idx in enumerate(sfs_idxs):
            not_final = i!=len(sfs_idxs)-1
            # Upsampled channels come from the running decoder output; skip channels
            # from the hooked encoder layer.
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i==len(sfs_idxs)-3)
            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,
                                   **kwargs).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        # One more upsampling step unless the first encoder layer already outputs `imsize`.
        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
        # NOTE(review): this throwaway upsampler is applied to the dummy activation even
        # when none was appended above; in that case the check below likely adds a resize
        # that is a no-op at runtime. Left as-is: changing it would alter the layer layout
        # (and likely the state-dict keys of saved models).
        x = PixelShuffle_ICNR(ni)(x)
        if imsize != x.shape[-2:]: layers.append(Lambda(lambda x: F.interpolate(x, imsize, mode='nearest')))
        if last_cross:
            # Dense merge; the channel count grows by the encoder's input channels
            # (presumably the model input is concatenated back in).
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, **kwargs))
        layers += [conv_layer(ni, n_classes, ks=1, use_activ=False, **kwargs)]
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        # Remove the encoder hooks; the hasattr guard covers a constructor that
        # failed before `self.sfs` was assigned.
        if hasattr(self, "sfs"): self.sfs.remove()
78
79
80