GitHub Repository: lucidrains/vit-pytorch
Path: blob/main/vit_pytorch/levit.py

from math import ceil

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, l = 3):
    val = val if isinstance(val, tuple) else (val,)
    return (*val, *((val[-1],) * max(l - len(val), 0)))

def always(val):
    return lambda *args, **kwargs: val
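
# illustration (added; not in the original file): cast_tuple broadcasts the
# per-stage hyperparameters used by LeViT below, e.g.
#   cast_tuple(256, 3)        -> (256, 256, 256)
#   cast_tuple((256, 384), 3) -> (256, 384, 384)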

# classes

class FeedForward(nn.Module):
    # channelwise MLP built from 1x1 convolutions, since LeViT keeps activations
    # as 2d feature maps rather than token sequences
    def __init__(self, dim, mult, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1),
            nn.Hardswish(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mult, dim, 1),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        inner_dim_key = dim_key * heads
        inner_dim_value = dim_value * heads
        dim_out = default(dim_out, dim)

        self.heads = heads
        self.scale = dim_key ** -0.5

        # queries may be strided (downsample) to halve the spatial resolution of the output
        self.to_q = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, stride = (2 if downsample else 1), bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_k = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # zero-init the output batchnorm gain so the attention branch initially contributes nothing
        out_batch_norm = nn.BatchNorm2d(dim_out)
        nn.init.zeros_(out_batch_norm.weight)

        self.to_out = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(inner_dim_value, dim_out, 1),
            out_batch_norm,
            nn.Dropout(dropout)
        )

        # positional bias

        self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)

        q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
        k_range = torch.arange(fmap_size)

        q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing = 'ij'), dim = -1)
        k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing = 'ij'), dim = -1)

        q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
        rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()

        x_rel, y_rel = rel_pos.unbind(dim = -1)
        pos_indices = (x_rel * fmap_size) + y_rel

        self.register_buffer('pos_indices', pos_indices)
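
    # note (added; not in the original file): pos_indices maps each (query, key)
    # pair to an embedding row via |dx| * fmap_size + |dy|, so every head learns a
    # scalar bias per absolute (|dx|, |dy|) offset; apply_pos_bias adds that bias
    # to the attention logits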

    def apply_pos_bias(self, fmap):
        bias = self.pos_bias(self.pos_indices)
        bias = rearrange(bias, 'i j h -> () h i j')
        return fmap + (bias / self.scale)

    def forward(self, x):
        b, n, *_, h = *x.shape, self.heads

        q = self.to_q(x)
        y = q.shape[2]  # height of the (possibly downsampled) query grid, used to reshape the output

        qkv = (q, self.to_k(x), self.to_v(x))
        # flatten each feature map into a per-head sequence of spatial positions
        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        dots = self.apply_pos_bias(dots)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # fold heads back into channels and restore the 2d spatial grid
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)
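
# shape note (added; not in the original file): for an input (b, dim, f, f) with
# f = fmap_size, keys and values cover all f * f positions, while queries lie on an
# f x f grid, or a ceil(f / 2) x ceil(f / 2) grid when downsample = True, so the
# block returns a feature map of shape (b, dim_out, f_q, f_q)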

class Transformer(nn.Module):
    def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.layers = nn.ModuleList([])
        self.attn_residual = (not downsample) and dim == dim_out

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out),
                FeedForward(dim_out, mlp_mult, dropout = dropout)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            attn_res = (x if self.attn_residual else 0)
            x = attn(x) + attn_res
            x = ff(x) + x
        return x
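
# note (added; not in the original file): when downsample = True the attention
# residual is skipped, since the spatial grid and channel count change across the
# block; LeViT below uses such depth-1 transformers to shrink the feature map
# between stages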

class LeViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_mult,
        stages = 3,
        dim_key = 32,
        dim_value = 64,
        dropout = 0.,
        num_distill_classes = None
    ):
        super().__init__()

        dims = cast_tuple(dim, stages)
        depths = cast_tuple(depth, stages)
        layer_heads = cast_tuple(heads, stages)

        assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be tuples of length at most the designated number of stages'

        # convolutional stem: four stride-2 convolutions downsample the image by 16x
        self.conv_embedding = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
            nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
            nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
            nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
        )

        fmap_size = image_size // (2 ** 4)
        layers = []

        for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads):
            is_last = ind == (stages - 1)
            layers.append(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout))

            if not is_last:
                next_dim = dims[ind + 1]
                # a single downsampling transformer (with doubled heads) bridges consecutive stages
                layers.append(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
                fmap_size = ceil(fmap_size / 2)

        self.backbone = nn.Sequential(*layers)

        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...')
        )

        # dim was rebound by the loop above, so it now holds the last stage's dimension
        self.distill_head = nn.Linear(dim, num_distill_classes) if exists(num_distill_classes) else always(None)
        self.mlp_head = nn.Linear(dim, num_classes)
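
    # resolution sketch (added; not in the original file): with image_size = 224 the
    # stem yields a 14 x 14 feature map; each downsampling transformer halves it,
    # giving 14 -> 7 -> 4 across the default 3 stages before global average pooling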

    def forward(self, img):
        x = self.conv_embedding(img)

        x = self.backbone(x)

        x = self.pool(x)

        out = self.mlp_head(x)
        distill = self.distill_head(x)  # None unless num_distill_classes was given

        if exists(distill):
            return out, distill

        return out
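
# ---------------------------------------------------------------------------
# example usage (added; not part of the original file): a minimal sketch of
# building LeViT and running a forward pass; the hyperparameters below are
# illustrative choices, not values prescribed by this file
if __name__ == '__main__':
    model = LeViT(
        image_size = 224,
        num_classes = 1000,
        stages = 3,                 # number of stages
        dim = (256, 384, 512),      # dimension per stage (a single int is broadcast by cast_tuple)
        depth = 4,                  # transformer depth per stage
        heads = (4, 6, 8),          # attention heads per stage
        mlp_mult = 2,
        dropout = 0.1
    )

    img = torch.randn(2, 3, 224, 224)

    logits = model(img)
    print(logits.shape)             # torch.Size([2, 1000])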