Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
lucidrains
GitHub Repository: lucidrains/vit-pytorch
Path: blob/main/tests/test_vit.py
649 views
1
import torch
2
from vit_pytorch import ViT
3
4
def test_vit():
    """Smoke-test that ViT maps a batch of images to one logit vector per image.

    Builds a 6-layer ViT for 256x256 inputs with 3 channels, cut into 32x32
    patches (an 8x8 grid, i.e. 64 patches), runs one forward pass and checks
    only the shape of the returned logits.
    """
    batch_size = 2  # > 1 so a dropped, squeezed or transposed batch axis is caught
    image_size = 256
    num_classes = 1000

    v = ViT(
        image_size = image_size,
        patch_size = 32,
        num_classes = num_classes,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    img = torch.randn(batch_size, 3, image_size, image_size)

    # Only the output shape is asserted, so there is no need to record the
    # autograd graph for this forward pass.
    with torch.no_grad():
        preds = v(img)

    # The assert message is shown only on failure, so it describes the failure.
    expected = (batch_size, num_classes)
    assert preds.shape == expected, \
        f'expected logits of shape {expected}, got {tuple(preds.shape)}'