GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/metrics/vit.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
from functools import partial

import torch
import torch.nn as nn

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class VisionTransformer(nn.Module):
    """ Vision Transformer """
    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., num_last_blocks=4, norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.num_last_blocks = num_last_blocks

        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.linear = nn.Linear(embed_dim * self.num_last_blocks, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def get_logits(self, x):
        x = x.view(x.size(0), -1)
        return self.linear(x)

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output

    def forward(self, x):
        intermediate_output = self.get_intermediate_layers(x, self.num_last_blocks)
        embed = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
        output = self.get_logits(embed)
        return embed, output

def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_small(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x

def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output

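# Note: drop_path divides the kept samples by keep_prob so that the expected value
# of the residual branch is preserved during training (e.g. with drop_prob=0.1 a
# surviving sample is scaled by 1 / 0.9).
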
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
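

# Minimal usage sketch (an illustration only, assuming a randomly initialized
# vit_small and a dummy 224x224 batch): checks the shapes produced by forward()
# and get_last_selfattention().
if __name__ == "__main__":
    model = vit_small(patch_size=16)
    model.eval()
    images = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        embed, output = model(images)
        attn = model.get_last_selfattention(images)
    # embed concatenates the [CLS] token of the last num_last_blocks (=4) blocks,
    # so it is (2, 4 * 384); with num_classes=0 the linear head is an identity.
    print(embed.shape, output.shape)
    # attention of the last block: (batch, heads, tokens, tokens) = (2, 6, 197, 197)
    print(attn.shape)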