Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
lucidrains
GitHub Repository: lucidrains/vit-pytorch
Path: blob/main/vit_pytorch/learnable_memory_vit.py
649 views
1
import torch
2
from torch import nn
3
import torch.nn.functional as F
4
5
from einops import rearrange, repeat
6
from einops.layers.torch import Rearrange
7
8
# helpers
9
10
def exists(val):
    """Return True when ``val`` is anything other than ``None``.

    Falsy values such as ``0``, ``False`` or an empty container still count
    as existing; only ``None`` does not.
    """
    if val is None:
        return False
    return True
12
13
def pair(t):
    """Normalize ``t`` into a 2-tuple such as ``(height, width)``.

    - A tuple is returned unchanged.
    - A list is converted to a tuple, so ``pair([224, 160]) == (224, 160)``.
      Previously it was duplicated into a tuple of two lists, which broke callers
      that unpack the result into two ints.
    - Any other value is duplicated: ``pair(16) == (16, 16)``.
    """
    if isinstance(t, tuple):
        return t
    if isinstance(t, list):
        return tuple(t)
    return (t, t)
15
16
# controlling freezing of layers
17
18
def set_module_requires_grad_(module, requires_grad):
    """Set ``requires_grad`` in place on every parameter of ``module``.

    Every tensor yielded by ``module.parameters()`` is updated. Returns None.
    """
    all_params = module.parameters()
    for p in all_params:
        p.requires_grad = requires_grad
21
22
def freeze_all_layers_(module):
    """Freeze ``module`` in place by disabling gradients on all of its parameters."""
    set_module_requires_grad_(module, requires_grad = False)
24
25
def unfreeze_all_layers_(module):
    """Unfreeze ``module`` in place by re-enabling gradients on all of its parameters."""
    set_module_requires_grad_(module, requires_grad = True)
27
28
# classes
29
30
class FeedForward(nn.Module):
    """Pre-norm position-wise MLP.

    Pipeline: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Dropout ->
    Linear(hidden_dim, dim) -> Dropout. Every step acts on the last axis, so
    the output keeps the input's shape. No residual connection is added here.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Keep this exact order: the Sequential indices become state_dict
        # keys (e.g. ``net.1.weight``), and construction order fixes which
        # random numbers each Linear is initialised with.
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        out = self.net(x)
        return out
43
44
class Attention(nn.Module):
    """Pre-norm multi-head attention with optional extra key/value memories.

    Queries come only from the normalized input tokens. Keys and values come
    from those same tokens, optionally followed by ``memories`` appended along
    the sequence axis.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # Creation order (to_q, to_kv, to_out) fixes the random initialisation.
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, attn_mask = None, memories = None):
        """Attend from ``x`` over ``x`` (plus optional memories).

        Args:
            x: token tensor of shape (batch, seq, dim). The rank is implied by
                the 'b n (h d)' rearrange below.
            attn_mask: optional boolean tensor. True marks a query/key pair that
                may attend; False pairs are filled with the most negative finite
                value before the softmax. It must broadcast against
                (batch, heads, seq, seq + num_memories).
            memories: optional extra key/value tokens. A 2-D tensor
                (num_memories, dim) is shared across the batch. Any other rank
                is concatenated as-is, presumably (batch, num_memories, dim).

        Returns:
            Tensor of shape (batch, seq, dim).
        """
        x = self.norm(x)
        batch = x.shape[0]

        kv_source = x
        if memories is not None:
            if memories.ndim == 2:
                # a single shared memory bank: copy it for every batch element
                memories = repeat(memories, 'n d -> b n d', b = batch)
            kv_source = torch.cat((kv_source, memories), dim = 1)

        q = self.to_q(x)
        k, v = self.to_kv(kv_source).chunk(2, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        sim = (q @ k.transpose(-1, -2)) * self.scale

        if attn_mask is not None:
            # disallowed positions get the most negative finite value so the
            # softmax assigns them effectively zero weight
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

        weights = self.dropout(self.attend(sim))

        merged = rearrange(weights @ v, 'b h n d -> b n (h d)')
        return self.to_out(merged)
88
89
class Transformer(nn.Module):
    """Stack of ``depth`` residual (attention, feed-forward) layers.

    ``forward`` can take per-layer memories: ``memories[i]`` is passed to
    layer ``i``'s attention as extra key/value tokens.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        blocks = []
        for _ in range(depth):
            attn = Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)
            ff = FeedForward(dim, mlp_dim, dropout = dropout)
            # each layer is a ModuleList [attention, feed-forward]; this layout
            # defines the state_dict keys (layers.<i>.0.*, layers.<i>.1.*)
            blocks.append(nn.ModuleList([attn, ff]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, attn_mask = None, memories = None):
        """Run ``x`` through every layer, with a residual add after each sub-block.

        Args:
            x: input tokens.
            attn_mask: forwarded unchanged to every attention layer.
            memories: optional indexable with one entry per layer, so
                ``memories[i]`` goes to layer ``i``. Must have at least as many
                entries as there are layers.
        """
        for index, layer in enumerate(self.layers):
            attn, ff = layer
            mem = None if memories is None else memories[index]

            x = attn(x, attn_mask = attn_mask, memories = mem) + x
            x = ff(x) + x
        return x
106
107
class ViT(nn.Module):
    """Vision Transformer image classifier.

    The image is cut into non-overlapping patches. Each patch is flattened and
    linearly embedded, a learnable CLS token is prepended, learned positional
    embeddings are added, and the token sequence runs through a
    ``Transformer``.

    Args:
        image_size: int or (height, width) of the input images.
        patch_size: int or (height, width) of each patch. It must divide the
            image size evenly.
        num_classes: number of output logits.
        dim: token embedding width.
        depth: number of transformer layers.
        heads: attention heads per layer.
        mlp_dim: hidden width of each feed-forward block.
        pool: 'cls' classifies from the CLS token; 'mean' classifies from the
            mean over all tokens, CLS included.
        channels: number of input image channels.
        dim_head: width of each attention head.
        dropout: dropout rate inside the transformer.
        emb_dropout: dropout rate applied to the token embeddings.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # Store the pooling choice so forward() honors it. Previously it was
        # validated but ignored, so pool='mean' silently used the CLS token.
        self.pool = pool

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # one positional embedding per patch plus one for the CLS token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def img_to_tokens(self, img):
        """Embed ``img`` into a token sequence with the CLS token at index 0.

        ``img`` is laid out as (batch, channels, height, width), per the
        patch Rearrange pattern. The result has shape
        (batch, num_patches + 1, dim) and includes positional embeddings and
        embedding dropout.
        """
        x = self.to_patch_embedding(img)

        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = x.shape[0])
        x = torch.cat((cls_tokens, x), dim = 1)

        x += self.pos_embedding
        x = self.dropout(x)
        return x

    def forward(self, img):
        """Return class logits of shape (batch, num_classes) for ``img``."""
        x = self.img_to_tokens(img)

        x = self.transformer(x)

        # 'mean' averages every token (CLS included); 'cls' takes the CLS token
        pooled = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        return self.mlp_head(pooled)
154
155
# adapter with learnable memories per layer, memory CLS token, and learnable adapter head
156
157
class Adapter(nn.Module):
    """Adapt a frozen ``ViT`` to a new task through learnable memory tokens.

    Every backbone parameter is frozen. The trainable parts are:

    - a memory CLS token prepended to the backbone's token sequence,
    - ``num_memories_per_layer`` memory tokens per transformer layer, used as
      extra attention keys/values,
    - a new classification head applied to the memory CLS output.

    A boolean attention mask lets only the memory CLS token attend to the
    memories and to itself. The backbone tokens attend only to each other,
    which is meant to preserve the original ViT's token computations.
    """

    def __init__(
        self,
        *,
        vit,
        num_memories_per_layer = 10,
        num_classes = 2,
    ):
        super().__init__()
        assert isinstance(vit, ViT)

        # sizes read off the backbone
        dim = vit.cls_token.shape[-1]
        depth = len(vit.transformer.layers)
        # Length of the backbone token sequence: patches plus ViT's own CLS
        # token (pos_embedding is sized num_patches + 1).
        num_tokens = vit.pos_embedding.shape[-2]

        self.vit = vit

        # freeze the ViT backbone so only the adapter parameters train
        freeze_all_layers_(vit)

        # trainable adapter parameters (creation order fixes random init)
        self.memory_cls_token = nn.Parameter(torch.randn(dim))
        self.memories_per_layer = nn.Parameter(torch.randn(depth, num_memories_per_layer, dim))

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

        # Mask layout (True = may attend):
        #   rows (queries): [memory CLS, backbone tokens...]
        #   cols (keys):    [memory CLS, backbone tokens..., layer memories...]
        # Row 0 (memory CLS) sees every key. Backbone rows see only the backbone
        # columns, never the memory CLS token or the layer memories.
        num_queries = num_tokens + 1
        num_keys = num_queries + num_memories_per_layer
        attn_mask = torch.zeros((num_queries, num_keys), dtype = torch.bool)
        attn_mask[0, :] = True
        attn_mask[1:, 1:num_queries] = True
        self.register_buffer('attn_mask', attn_mask)

    def forward(self, img):
        """Return task logits of shape (batch, num_classes) for ``img``."""
        backbone_tokens = self.vit.img_to_tokens(img)

        # prepend one copy of the task-specific memory CLS token per image
        mem_cls = repeat(self.memory_cls_token, 'd -> b 1 d', b = img.shape[0])
        sequence = torch.cat((mem_cls, backbone_tokens), dim = 1)

        # layer i's attention also receives memories_per_layer[i] as keys/values
        encoded = self.vit.transformer(sequence, memories = self.memories_per_layer, attn_mask = self.attn_mask)

        # classify from the memory CLS position with the adapter head
        return self.mlp_head(encoded[:, 0])
219
220