Path: blob/master/src/metrics/swin_transformer.py
# --------------------------------------------------------
"""
Swin Transformer
Copyright (c) 2021 Microsoft
Licensed under The MIT License [see LICENSE for details]
Written by Ze Liu

MIT License

Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
"""
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
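

# --------------------------------------------------------
# Added illustration (not part of the original file; the helper name below is
# ours). A minimal sanity-check sketch of the two helpers above, assuming H and
# W are divisible by window_size: partitioning into windows and reversing the
# partition should reproduce the input tensor exactly, since both functions are
# pure view/permute operations.
def _window_partition_roundtrip_example():
    x = torch.randn(2, 14, 14, 32)                    # B, H, W, C
    windows = window_partition(x, window_size=7)      # (2 * 2 * 2, 7, 7, 32) = (8, 7, 7, 32)
    x_back = window_reverse(windows, 7, H=14, W=14)   # back to (2, 14, 14, 32)
    assert torch.equal(x, x_back)
# --------------------------------------------------------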


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
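

# --------------------------------------------------------
# Added illustration (not part of the original file; the helper name below is
# ours). A small sketch that recomputes the relative position index for a tiny
# 2x2 window, mirroring the arithmetic in WindowAttention.__init__: the
# resulting (Wh*Ww, Wh*Ww) index tensor holds values in
# [0, (2*Wh - 1) * (2*Ww - 1)), i.e. [0, 9) here, and is used to gather rows of
# relative_position_bias_table.
def _relative_position_index_example():
    Wh = Ww = 2
    coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)                                   # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]   # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    relative_coords[:, :, 0] += Wh - 1   # shift row offsets to start from 0
    relative_coords[:, :, 1] += Ww - 1   # shift column offsets to start from 0
    relative_coords[:, :, 0] *= 2 * Ww - 1
    return relative_coords.sum(-1)       # (4, 4) index tensor, values in [0, 9)
# --------------------------------------------------------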


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
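

# --------------------------------------------------------
# Added illustration (not part of the original file; the helper name below is
# ours). A hedged usage sketch: a single block consumes tokens flattened to
# (B, H*W, C) for its configured input_resolution and returns the same shape.
# shift_size > 0 selects the shifted-window (SW-MSA) variant, which uses the
# precomputed attn_mask buffer.
def _swin_block_example():
    blk = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3,
                               window_size=7, shift_size=3)
    tokens = torch.randn(2, 56 * 56, 96)
    out = blk(tokens)
    assert out.shape == tokens.shape   # (2, 3136, 96)
# --------------------------------------------------------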


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
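

# --------------------------------------------------------
# Added illustration (not part of the original file; the helper name below is
# ours). PatchMerging halves the spatial resolution and doubles the channel
# count: (B, H*W, C) -> (B, H/2 * W/2, 2*C). A quick shape check under that
# assumption:
def _patch_merging_example():
    pm = PatchMerging(input_resolution=(56, 56), dim=96)
    tokens = torch.randn(2, 56 * 56, 96)
    assert pm(tokens).shape == (2, 28 * 28, 192)
# --------------------------------------------------------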


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding
    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
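

# --------------------------------------------------------
# Added illustration (not part of the original file; the helper name below is
# ours). PatchEmbed turns an image into a sequence of patch tokens via a
# strided convolution: with the defaults above,
# (B, 3, 224, 224) -> (B, (224 / 4)**2, 96) = (B, 3136, 96).
def _patch_embed_example():
    pe = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
    imgs = torch.randn(2, 3, 224, 224)
    assert pe(imgs).shape == (2, 56 * 56, 96)
# --------------------------------------------------------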


class SwinTransformer(nn.Module):
    r""" Swin Transformer
    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030
    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        h = self.forward_features(x)
        x = self.head(h)
        return h, x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
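

# --------------------------------------------------------
# Added illustration (not part of the original file). A hedged end-to-end
# sketch using the defaults above (a Swin-B-like config: embed_dim=128,
# depths=[2, 2, 18, 2]). Note that, unlike the reference classification model,
# this variant's forward returns both the pooled features h of size
# num_features = embed_dim * 2**(num_layers - 1) = 1024 and the classification
# head output of size num_classes.
if __name__ == "__main__":
    model = SwinTransformer()
    images = torch.randn(1, 3, 224, 224)
    features, logits = model(images)
    print(features.shape)   # torch.Size([1, 1024])
    print(logits.shape)     # torch.Size([1, 1000])
# --------------------------------------------------------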