# lucidrains/vit-pytorch: blob/main/train_vit_decorr.py

# /// script
# dependencies = [
#     "accelerate",
#     "vit-pytorch",
#     "wandb"
# ]
# ///
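
# the block above is PEP 723 inline script metadata; runners such as `uv run`
# read it and install the listed dependencies before executing the script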

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision.transforms as T
from torchvision.datasets import CIFAR100

# constants

BATCH_SIZE = 32
LEARNING_RATE = 3e-4
EPOCHS = 10
DECORR_LOSS_WEIGHT = 1e-1       # weight on the auxiliary decorrelation loss

TRACK_EXPERIMENT_ONLINE = False # set True to sync the wandb run to the cloud

# helpers

def exists(v):
    return v is not None

# data

transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # scale pixels to roughly [-1, 1]
])

dataset = CIFAR100(
    root = 'data',
    download = True,
    train = True,
    transform = transform
)

dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)

# model

from vit_pytorch.vit_with_decorr import ViT

vit = ViT(
    dim = 128,
    num_classes = 100,
    image_size = 32,        # CIFAR100 images are 32 x 32
    patch_size = 4,         # 4 x 4 patches -> (32 / 4) ** 2 = 64 tokens
    depth = 6,
    heads = 8,
    dim_head = 64,
    mlp_dim = 128 * 4,
    decorr_sample_frac = 1. # use all tokens for the decorrelation loss
)
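
# the forward pass of this ViT variant returns the logits together with an
# auxiliary decorrelation loss, which is combined with the task loss in the
# training loop below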

# optim

from torch.optim import Adam

optim = Adam(vit.parameters(), lr = LEARNING_RATE)

# prepare

from accelerate import Accelerator

accelerator = Accelerator()

# prepare() moves the model, optimizer, and dataloader to the right device and
# wraps them for distributed training when launched with `accelerate launch`

vit, optim, dataloader = accelerator.prepare(vit, optim, dataloader)

# experiment

import wandb

wandb.init(
    project = 'vit-decorr',
    mode = 'online' if TRACK_EXPERIMENT_ONLINE else 'disabled'
)

wandb.run.name = 'baseline'

# loop

for _ in range(EPOCHS):
    for images, labels in dataloader:

        logits, decorr_aux_loss = vit(images)
        loss = F.cross_entropy(logits, labels)

        # total objective = classification loss + weighted decorrelation loss

        total_loss = (
            loss +
            decorr_aux_loss * DECORR_LOSS_WEIGHT
        )

        wandb.log(dict(loss = loss.item(), decorr_loss = decorr_aux_loss.item()))

        accelerator.print(f'loss: {loss.item():.3f} | decorr aux loss: {decorr_aux_loss.item():.3f}')

        accelerator.backward(total_loss)
        optim.step()
        optim.zero_grad()
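
wandb.finish()

# optional: a minimal evaluation sketch on the CIFAR100 test split, an addition
# not present in the original script - reports top-1 accuracy per process

test_dataset = CIFAR100(
    root = 'data',
    download = True,
    train = False,
    transform = transform
)

test_dataloader = accelerator.prepare(DataLoader(test_dataset, batch_size = BATCH_SIZE))

vit.eval()

correct = total = 0

with torch.no_grad():
    for images, labels in test_dataloader:
        out = vit(images)

        # the forward may or may not return the aux loss at eval time
        logits = out[0] if isinstance(out, tuple) else out

        correct += (logits.argmax(dim = -1) == labels).sum().item()
        total += labels.shape[0]

accelerator.print(f'test accuracy: {correct / total:.3f}')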