# Path: blob/main/vit_pytorch/learnable_memory_vit.py
# 649 views
"""Vision Transformer with per-layer learnable memories and a task adapter."""

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    """Return True when `val` is not None."""
    return val is not None

def pair(t):
    """Return `t` unchanged if it is a tuple, otherwise duplicate it as `(t, t)`."""
    return t if isinstance(t, tuple) else (t, t)

# controlling freezing of layers

def set_module_requires_grad_(module, requires_grad):
    """Set `requires_grad` on every parameter of `module`, in place."""
    for param in module.parameters():
        param.requires_grad = requires_grad

def freeze_all_layers_(module):
    """Disable gradients for every parameter of `module`, in place."""
    set_module_requires_grad_(module, False)

def unfreeze_all_layers_(module):
    """Enable gradients for every parameter of `module`, in place."""
    set_module_requires_grad_(module, True)

# classes

class FeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply the MLP to `x`; the last dimension stays `dim`."""
        return self.net(x)

class Attention(nn.Module):
    """Pre-norm multi-head attention whose keys / values can be extended with memories."""

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, attn_mask = None, memories = None):
        """Attend over `x`, optionally also over `memories`.

        Args:
            x: tokens of shape (batch, tokens, dim).
            attn_mask: optional boolean mask, True where attention is allowed.
                It is applied to the (batch, heads, queries, keys) similarity
                tensor, so it must broadcast to that shape; the key axis
                includes any appended memories.
            memories: optional extra key / value tokens, either (n, dim)
                (shared across the batch) or already batched. They are
                concatenated after the normalized `x` and are never queries.

        Returns:
            Tensor of shape (batch, tokens, dim).
        """
        x = self.norm(x)

        x_kv = x # input for key / values projection

        if exists(memories):
            # add memories to key / values if it is passed in
            if memories.ndim == 2:
                memories = repeat(memories, 'n d -> b n d', b = x.shape[0])
            x_kv = torch.cat((x_kv, memories), dim = 1)

        qkv = (self.to_q(x), *self.to_kv(x_kv).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        if exists(attn_mask):
            # disallowed positions get the most negative finite value so softmax sends them to ~0
            dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    """Stack of `depth` residual (Attention, FeedForward) layers."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x, attn_mask = None, memories = None):
        """Run `x` through every layer.

        Args:
            x: tokens of shape (batch, tokens, dim).
            attn_mask: optional boolean mask forwarded unchanged to every
                attention layer.
            memories: optional indexable of per-layer memories; entry `i` is
                handed to layer `i`.

        Returns:
            Tensor with the same shape as `x`.
        """
        for ind, (attn, ff) in enumerate(self.layers):
            layer_memories = memories[ind] if exists(memories) else None

            x = attn(x, attn_mask = attn_mask, memories = layer_memories) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    """Vision Transformer classifier over non-overlapping image patches."""

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        """Build the patch embedding, transformer and classification head.

        Args:
            image_size: int or (height, width) of the input images.
            patch_size: int or (height, width) of one patch; must divide
                `image_size` exactly.
            num_classes: number of output logits.
            dim: token embedding width.
            depth: number of transformer layers.
            heads: attention heads per layer.
            mlp_dim: hidden width of each feed-forward block.
            pool: 'cls' to classify from the CLS token, 'mean' to classify
                from the mean over all output tokens.
            channels: number of image channels.
            dim_head: width of each attention head.
            dropout: dropout rate inside the transformer.
            emb_dropout: dropout rate applied to the embedded tokens.
        """
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # remembered so that forward() can honor the requested pooling
        self.pool = pool

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # one position per patch, plus one for the CLS token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def img_to_tokens(self, img):
        """Embed `img` (batch, channels, height, width) into tokens.

        Returns:
            Tensor of shape (batch, num_patches + 1, dim): the CLS token
            followed by the patch tokens, with position embeddings added and
            embedding dropout applied.
        """
        x = self.to_patch_embedding(img)

        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = x.shape[0])
        x = torch.cat((cls_tokens, x), dim = 1)

        x += self.pos_embedding
        x = self.dropout(x)
        return x

    def forward(self, img):
        """Classify `img`, returning logits of shape (batch, num_classes)."""
        x = self.img_to_tokens(img)

        x = self.transformer(x)

        # 'mean' averages every output token; 'cls' reads the CLS token at index 0
        pooled = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        return self.mlp_head(pooled)

# adapter with learnable memories per layer, memory CLS token, and learnable adapter head

class Adapter(nn.Module):
    """Task adapter around a frozen `ViT`.

    Adds a learnable memory CLS token, learnable memories for every
    transformer layer and a new classification head; the wrapped ViT's
    parameters are frozen on construction.
    """

    def __init__(
        self,
        *,
        vit,
        num_memories_per_layer = 10,
        num_classes = 2,
    ):
        """Wrap `vit` and create the learnable adapter parameters.

        Args:
            vit: a `ViT` instance; all of its parameters get
                `requires_grad = False`.
            num_memories_per_layer: number of memory tokens learned for each
                transformer layer.
            num_classes: number of logits produced by the adapter head.
        """
        super().__init__()
        assert isinstance(vit, ViT)

        # extract some model variables needed

        dim = vit.cls_token.shape[-1]
        layers = len(vit.transformer.layers)
        # length of the ViT token sequence: patches plus the ViT's own CLS token
        num_tokens = vit.pos_embedding.shape[-2]

        self.vit = vit

        # freeze ViT backbone - only memories will be finetuned

        freeze_all_layers_(vit)

        # learnable parameters

        self.memory_cls_token = nn.Parameter(torch.randn(dim))
        self.memories_per_layer = nn.Parameter(torch.randn(layers, num_memories_per_layer, dim))

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

        # specialized attention mask to preserve the output of the original ViT
        # it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa
        #
        # rows are queries    : [memory CLS, ViT tokens...]
        # columns are keys    : [memory CLS, ViT tokens..., layer memories...]

        attn_mask = torch.ones((num_tokens, num_tokens), dtype = torch.bool)
        # one blocked column on the left (memory CLS) and `num_memories_per_layer` on the right
        attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False)  # main tokens cannot attend to learnable memories per layer
        # one fully allowed row on top for the memory CLS query
        attn_mask = F.pad(attn_mask, (0, 0, 1, 0), value = True)  # memory CLS token can attend to everything
        self.register_buffer('attn_mask', attn_mask)

    def forward(self, img):
        """Classify `img` with the adapter head.

        Returns:
            Logits of shape (batch, num_classes) computed from the memory CLS
            token after it has passed through the frozen transformer.
        """
        b = img.shape[0]

        tokens = self.vit.img_to_tokens(img)

        # add task specific memory tokens

        memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b)
        tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)

        # pass memories along with image tokens through transformer for attending

        out = self.vit.transformer(tokens, memories = self.memories_per_layer, attn_mask = self.attn_mask)

        # extract memory CLS tokens

        memory_cls_tokens = out[:, 0]

        # pass through task specific adapter head

        return self.mlp_head(memory_cls_tokens)