GitHub Repository: lucidrains/vit-pytorch
Path: blob/main/vit_pytorch/parallel_vit.py
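
# Parallel ViT: a Vision Transformer variant in which each depth runs several
# attention branches and several feed-forward branches side by side and sums
# their outputs before the residual connection (see the Parallel module below).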
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class Parallel(nn.Module):
    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        # run every branch on the same input and sum the outputs
        return sum([fn(x) for fn in self.fns])

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_parallel_branches = 2, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])

        attn_block = lambda: Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)
        ff_block = lambda: FeedForward(dim, mlp_dim, dropout = dropout)

        # each layer holds `num_parallel_branches` attention blocks and feed-forward blocks,
        # executed in parallel and summed before the residual connection
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Parallel(*[attn_block() for _ in range(num_parallel_branches)]),
                Parallel(*[ff_block() for _ in range(num_parallel_branches)]),
            ]))

    def forward(self, x):
        for attns, ffs in self.layers:
            x = attns(x) + x
            x = ffs(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', num_parallel_branches = 2, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_parallel_branches, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
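
# Minimal usage sketch. The hyperparameter values below are illustrative
# assumptions, not values prescribed by the repository; only the constructor
# arguments themselves come from the ViT class defined above.
if __name__ == '__main__':
    model = ViT(
        image_size = 256,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
        num_parallel_branches = 2,   # two parallel attention / feed-forward branches per layer
        dropout = 0.1,
        emb_dropout = 0.1
    )

    img = torch.randn(1, 3, 256, 256)   # (batch, channels, height, width)
    preds = model(img)                   # expected shape: (1, 1000)
    print(preds.shape)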