GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/metrics/swin_transformer.py
# --------------------------------------------------------
"""
Swin Transformer
Copyright (c) 2021 Microsoft
Licensed under The MIT License [see LICENSE for details]
Written by Ze Liu

MIT License

Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
"""
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
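
# Shape example (illustrative note, not part of the original file): for x of shape
# (2, 56, 56, 96) and window_size=7, window_partition returns
# (2 * 8 * 8, 7, 7, 96) = (128, 7, 7, 96) windows, and window_reverse(windows, 7, 56, 56)
# recovers the original (2, 56, 56, 96) tensor, so the two functions are inverses
# whenever H and W are divisible by window_size.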


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
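
# Illustrative numbers (not part of the original file): with window_size=(7, 7) the relative
# position bias table has (2*7 - 1) * (2*7 - 1) = 169 entries per head, and forward() maps an
# input of shape (num_windows*B, 49, dim) to an output of the same shape.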


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
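
# Note on the SW-MSA mask above (explanatory comment, not from the original file): the cyclic
# shift makes some windows contain patches from opposite image borders. img_mask labels the
# 3x3 grid of shifted regions 0..8, and position pairs with different labels receive a -100.0
# bias, so the softmax effectively zeroes attention between patches that were not adjacent
# before the shift.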


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
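
# Shape example (illustrative note, not part of the original file): with
# input_resolution=(56, 56) and dim=96, forward() maps (B, 3136, 96) to (B, 784, 192):
# the spatial resolution is halved in each dimension while the channel count doubles.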


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
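
# Explanatory note (not from the original file): the blocks alternate shift_size=0 (W-MSA)
# and shift_size=window_size//2 (SW-MSA), so consecutive blocks attend within regular and
# shifted windows; the optional PatchMerging downsample then halves the spatial resolution
# before the next stage.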


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding
    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
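
# Shape example (illustrative note, not part of the original file): with img_size=224 and
# patch_size=4, the Conv2d projection produces a 56x56 grid, i.e. 3136 patch tokens, so
# forward() maps (B, 3, 224, 224) to (B, 3136, embed_dim).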


class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
        https://arxiv.org/pdf/2103.14030
    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 128
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.2
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        h = self.forward_features(x)
        x = self.head(h)
        return h, x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
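

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file): build the model with the
    # default configuration defined above (embed_dim=128, depths=[2, 2, 18, 2]) and run a
    # random 224x224 batch. As implemented in this file, forward() returns a
    # (features, logits) tuple.
    model = SwinTransformer()
    dummy = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        features, logits = model(dummy)
    # Expected shapes: features (2, 1024) since num_features = 128 * 2**3, logits (2, 1000).
    print(features.shape, logits.shape)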