GitHub Repository: lucidrains/vit-pytorch
Path: blob/main/vit_pytorch/pit.py
from math import sqrt

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def cast_tuple(val, num):
    return val if isinstance(val, tuple) else (val,) * num

def conv_output_size(image_size, kernel_size, stride, padding = 0):
    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
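# illustrative check of the helper above (values assumed for this example, not taken
# from this file): with image_size = 224, kernel_size = 14 and stride = 7,
# conv_output_size returns (224 - 14) / 7 + 1 = 31, i.e. a 31 x 31 grid of patches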
# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads

        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# depthwise convolution, for pooling

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_out, dim_out, kernel_size = 1, bias = bias)
        )

    def forward(self, x):
        return self.net(x)
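# note on the block above: the spatial convolution uses groups = dim_in, so each input
# channel is convolved separately (dim_out must be a multiple of dim_in), and the
# following 1 x 1 pointwise convolution mixes information across channels - the usual
# depthwise-separable factorization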
# pooling layer

class Pool(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.downsample = DepthWiseConv2d(dim, dim * 2, kernel_size = 3, stride = 2, padding = 1)
        self.cls_ff = nn.Linear(dim, dim * 2)

    def forward(self, x):
        cls_token, tokens = x[:, :1], x[:, 1:]

        cls_token = self.cls_ff(cls_token)

        tokens = rearrange(tokens, 'b (h w) c -> b c h w', h = int(sqrt(tokens.shape[1])))
        tokens = self.downsample(tokens)
        tokens = rearrange(tokens, 'b c h w -> b (h w) c')

        return torch.cat((cls_token, tokens), dim = 1)
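# illustrative shape walk-through (sizes assumed for this example, not taken from this
# file): with a 31 x 31 token grid at dim = 256, the strided depthwise conv
# (kernel 3, stride 2, padding 1) yields a 16 x 16 grid at dim 512, so the sequence
# shrinks from 1 + 961 to 1 + 256 tokens while the cls token is projected to dim 512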
# main class

class PiT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        channels = 3
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
        heads = cast_tuple(heads, len(depth))

        patch_dim = channels * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),
            Rearrange('b c n -> b n c'),
            nn.Linear(patch_dim, dim)
        )

        output_size = conv_output_size(image_size, patch_size, patch_size // 2)
        num_patches = output_size ** 2

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        layers = []

        for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
            not_last = ind < (len(depth) - 1)

            layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))

            if not_last:
                layers.append(Pool(dim))
                dim *= 2

        self.layers = nn.Sequential(*layers)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :n+1]
        x = self.dropout(x)

        x = self.layers(x)

        return self.mlp_head(x[:, 0])
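# minimal usage sketch - the hyperparameter values below are illustrative (they follow
# the repository README) and are not defined anywhere in this file

if __name__ == '__main__':
    v = PiT(
        image_size = 224,
        patch_size = 14,
        dim = 256,
        num_classes = 1000,
        depth = (3, 3, 3),        # three transformer stages, with a Pool between consecutive stages
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    img = torch.randn(1, 3, 224, 224)
    preds = v(img)                # (1, 1000)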