Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
lucidrains
GitHub Repository: lucidrains/vit-pytorch
Path: blob/main/vit_pytorch/normalized_vit.py
649 views
1
import torch
2
from torch import nn
3
from torch.nn import Module, ModuleList
4
import torch.nn.functional as F
5
import torch.nn.utils.parametrize as parametrize
6
7
from einops import rearrange, reduce
8
from einops.layers.torch import Rearrange
9
10
# functions
11
12
def exists(v):
    """Return True unless `v` is None (falsy values like 0 or '' still count as existing)."""
    if v is None:
        return False
    return True
14
15
def default(v, d):
    """Return `v`, falling back to `d` only when `v` is None."""
    if v is None:
        return d
    return v
17
18
def pair(t):
    """Return `t` unchanged if it is already a tuple, otherwise duplicate it into `(t, t)`."""
    if isinstance(t, tuple):
        return t
    return (t, t)
20
21
def divisible_by(numer, denom):
    """Return True when `numer` leaves no remainder after division by `denom`."""
    remainder = numer % denom
    return remainder == 0
23
24
def l2norm(t, dim = -1):
    """Scale `t` to unit Euclidean (p = 2) norm along axis `dim`, via `F.normalize`."""
    return F.normalize(t, p = 2, dim = dim)
26
27
# for use with parametrize
28
29
class L2Norm(Module):
    """
    Parametrization module that rescales its input to unit L2 norm along `dim`.

    Used with `torch.nn.utils.parametrize` so a weight is always read back
    normalized.
    """

    def __init__(self, dim = -1):
        super().__init__()
        # axis along which the L2 norm is taken
        self.dim = dim

    def forward(self, t):
        return F.normalize(t, p = 2, dim = self.dim)
36
37
class NormLinear(Module):
    """
    Bias-free linear layer whose weight is L2-normalized through a parametrization.

    An `L2Norm` parametrization is registered on the inner `nn.Linear`'s
    `weight`, so reading `self.linear.weight` (or `self.weight`) yields the
    normalized version of the stored, unconstrained parameter.

    Args:
        dim: number of input features.
        dim_out: number of output features.
        norm_dim_in: when True, normalize along the last weight axis (the input
            features of each output row); when False, normalize along axis 0.
    """

    def __init__(
        self,
        dim,
        dim_out,
        norm_dim_in = True
    ):
        super().__init__()
        self.linear = nn.Linear(dim, dim_out, bias = False)

        norm_axis = -1 if norm_dim_in else 0
        parametrize.register_parametrization(self.linear, 'weight', L2Norm(dim = norm_axis))

    @property
    def weight(self):
        # the parametrized (normalized) weight of the inner linear layer
        return self.linear.weight

    def forward(self, x):
        return self.linear(x)
59
60
# attention and feedforward
61
62
class Attention(Module):
    """
    Multi-head self-attention with L2-normalized queries and keys.

    After splitting into heads, each query and key vector is L2-normalized and
    multiplied by a learned scale of shape (heads, 1, dim_head) initialised to
    dim_head ** 0.25. The query-key dot product therefore starts out scaled by
    sqrt(dim_head), and the fused attention kernel is called with scale = 1.

    Args:
        dim: model dimension of the input tokens.
        dim_head: dimension per head.
        heads: number of heads.
        dropout: attention dropout probability, applied only in training mode.
    """

    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        inner_width = heads * dim_head

        # q, k, v projections (creation order kept: q, then k, then v)
        self.to_q = NormLinear(dim, inner_width)
        self.to_k = NormLinear(dim, inner_width)
        self.to_v = NormLinear(dim, inner_width)

        # dropout probability forwarded to scaled_dot_product_attention
        self.dropout = dropout

        qk_scale_init = dim_head ** 0.25
        self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * qk_scale_init)
        self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * qk_scale_init)

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        self.to_out = NormLinear(inner_width, dim, norm_dim_in = False)

    def forward(
        self,
        x
    ):
        projections = (self.to_q, self.to_k, self.to_v)
        q, k, v = (self.split_heads(project(x)) for project in projections)

        # L2-normalize each head's query / key vector, then apply the learned scales
        q = l2norm(q) * self.q_scale
        k = l2norm(k) * self.k_scale

        attn_dropout = self.dropout if self.training else 0.

        # scale is 1 because the sqrt(d_k) factor already lives in q_scale * k_scale (eq. 16)
        attended = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p = attn_dropout,
            scale = 1.
        )

        return self.to_out(self.merge_heads(attended))
112
113
class FeedForward(Module):
    """
    SiLU-gated feedforward block built from normalized linear layers.

    The requested `dim_inner` is reduced to int(dim_inner * 2 / 3) for the two
    parallel input projections (value path and gate path). Each projection
    output is multiplied by a learned per-channel scale (initialised to ones).
    The gate is additionally multiplied by sqrt(dim) before the SiLU.

    Args:
        dim: model dimension (input and output width).
        dim_inner: nominal hidden width before the 2/3 reduction.
        dropout: dropout probability applied to the gated hidden activations.
    """

    def __init__(
        self,
        dim,
        *,
        dim_inner,
        dropout = 0.
    ):
        super().__init__()
        hidden_width = int(dim_inner * 2 / 3)

        self.dim = dim
        self.dropout = nn.Dropout(dropout)

        # value-path and gate-path projections (created in this order)
        self.to_hidden = NormLinear(dim, hidden_width)
        self.to_gate = NormLinear(dim, hidden_width)

        self.hidden_scale = nn.Parameter(torch.ones(hidden_width))
        self.gate_scale = nn.Parameter(torch.ones(hidden_width))

        self.to_out = NormLinear(hidden_width, dim, norm_dim_in = False)

    def forward(self, x):
        gate_multiplier = self.dim ** 0.5

        values = self.to_hidden(x) * self.hidden_scale
        gates = self.to_gate(x) * self.gate_scale * gate_multiplier

        activated = F.silu(gates) * values
        return self.to_out(self.dropout(activated))
145
146
# classes
147
148
class nViT(Module):
    """
    Normalized transformer for image classification (https://arxiv.org/abs/2410.01131).

    Patch embeddings, sublayer outputs and the residual stream are all
    L2-normalized. Each residual update is a learned per-channel linear
    interpolation (lerp) from the current tokens toward the normalized sublayer
    output, followed by re-normalization.

    Args:
        image_size: int or (height, width) of the input images.
        patch_size: int or (patch_height, patch_width). Each image side must be
            divisible by the corresponding patch side.
        num_classes: number of output logits.
        dim: token (model) dimension.
        depth: number of attention + feedforward layer pairs.
        heads: attention heads per layer.
        mlp_dim: passed to each FeedForward as its `dim_inner`.
        dropout: dropout probability for attention and feedforward.
        channels: number of input image channels.
        dim_head: dimension per attention head.
        residual_lerp_scale_init: initial residual interpolation weight;
            defaults to 1 / depth.
    """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        dropout = 0.,
        channels = 3,
        dim_head = 64,
        residual_lerp_scale_init = None
    ):
        super().__init__()
        image_height, image_width = pair(image_size)

        # patch_size accepts an int or a (height, width) tuple, mirroring image_size
        patch_height, patch_width = pair(patch_size)

        # calculate patching related stuff

        assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'

        patch_height_dim, patch_width_dim = (image_height // patch_height), (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        num_patches = patch_height_dim * patch_width_dim

        self.channels = channels
        self.patch_size = patch_size

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_height, p2 = patch_width),
            NormLinear(patch_dim, dim, norm_dim_in = False),
        )

        # learned absolute positional embedding: one normalized row of size `dim` per patch
        self.abs_pos_emb = NormLinear(dim, num_patches)

        residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)

        # layers

        self.dim = dim
        self.scale = dim ** 0.5

        self.layers = ModuleList([])
        self.residual_lerp_scales = nn.ParameterList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, dim_head = dim_head, heads = heads, dropout = dropout),
                FeedForward(dim, dim_inner = mlp_dim, dropout = dropout),
            ]))

            # stored divided by self.scale; forward multiplies back by self.scale,
            # so the effective initial lerp weight is residual_lerp_scale_init
            self.residual_lerp_scales.append(nn.ParameterList([
                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
            ]))

        self.logit_scale = nn.Parameter(torch.ones(num_classes))

        self.to_pred = NormLinear(dim, num_classes)

    @torch.no_grad()
    def norm_weights_(self):
        """
        Copy the current normalized weight of every NormLinear into its stored
        (unconstrained) parametrization parameter, in place.
        """
        for module in self.modules():
            if not isinstance(module, NormLinear):
                continue

            normed = module.weight
            original = module.linear.parametrizations.weight.original

            original.copy_(normed)

    def forward(self, images):
        """
        Compute class logits for a batch of images.

        Args:
            images: tensor laid out as (batch, channels, height, width), with
                height and width splitting into whole patches and producing at
                most `num_patches` patches (one positional row per patch).

        Returns:
            Logits of shape (batch, num_classes).
        """
        device = images.device

        tokens = self.to_patch_embedding(images)

        seq_len = tokens.shape[-2]
        pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)]

        tokens = l2norm(tokens + pos_emb)

        for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):

            # normalized residual update: lerp toward the normalized sublayer output, then re-normalize
            attn_out = l2norm(attn(tokens))
            tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale))

            ff_out = l2norm(ff(tokens))
            tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale))

        pooled = reduce(tokens, 'b n d -> b d', 'mean')

        logits = self.to_pred(pooled)
        logits = logits * self.logit_scale * self.scale

        return logits
247
248
# quick test
249
250
if __name__ == '__main__':
    # quick smoke test: 256x256 RGB images -> 1000-way logits

    model_config = dict(
        image_size = 256,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
    )
    model = nViT(**model_config)

    batch_images = torch.randn(4, 3, 256, 256)
    logits = model(batch_images)  # (4, 1000)
    assert logits.shape == (4, 1000)
265
266