Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/Bag-Of-Tricks-For-Image-Classification/main.py
3119 views
1
import torch
2
from pytorch_lightning import (
3
Trainer,
4
seed_everything,
5
)
6
from pytorch_lightning.callbacks import ModelCheckpoint
7
from torchvision.models import (
8
resnet18,
9
resnet50,
10
)
11
12
from model.model import (
13
LitFood101,
14
LitFood101KD,
15
)
16
from utils.args import get_program_level_args
17
18
19
def _build_model(args):
    """Wrap an ImageNet-pretrained ResNet-18 in the Lightning module chosen by *args*.

    If ``args.use_knowledge_distillation`` is truthy, a ResNet-50 is created
    as the teacher and both networks are handed to ``LitFood101KD``;
    otherwise the ResNet-18 is wrapped in ``LitFood101``.

    Args:
        args: Parsed command-line namespace; also forwarded to the module.

    Returns:
        The Lightning module to train or evaluate.
    """
    student = resnet18(pretrained=True)
    if args.use_knowledge_distillation:
        # NOTE(review): the teacher is constructed without pretrained weights;
        # presumably LitFood101KD (or args) supplies trained teacher weights
        # elsewhere -- confirm, otherwise distillation uses a random teacher.
        teacher = resnet50(pretrained=False)
        return LitFood101KD(student, teacher, args)
    return LitFood101(student, args)


def main():
    """Parse CLI arguments, then either evaluate a checkpoint or train and test.

    Command-line options come from three sources: program-level arguments,
    ``LitFood101``'s model-specific arguments and Lightning's ``Trainer``
    arguments.

    With ``--evaluate``, the ``state_dict`` stored in ``args.checkpoint`` is
    loaded into the model and only the test loop runs.  Otherwise the model
    is fitted and then tested.

    Returns:
        0 once the selected workflow has finished.
    """
    parser = get_program_level_args()
    parser = LitFood101.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()
    seed_everything(args.seed)

    # Keep the checkpoint with the highest average validation accuracy.
    checkpoint_callback = ModelCheckpoint(monitor="avg_val_acc", mode="max")
    trainer = Trainer.from_argparse_args(
        args,
        deterministic=True,
        benchmark=False,
        checkpoint_callback=checkpoint_callback,
        # Any AMP level other than "O0" (pure FP32) switches to 16-bit precision.
        precision=16 if args.amp_level != "O0" else 32,
    )

    model = _build_model(args)

    if args.evaluate:
        # map_location="cpu" allows a checkpoint saved from GPU tensors to be
        # opened on a CPU-only machine; load_state_dict then copies the values
        # into the model's existing parameters.
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        # The trainer obtains the module's own test_dataloader() itself, after
        # its setup hooks, so no explicit loader needs to be passed.
        trainer.test(model)
        return 0

    trainer.fit(model)

    # No model argument: Lightning tests the model from the preceding fit
    # (in PL 1.x the default ckpt_path="best" selects the checkpoint kept by
    # checkpoint_callback).
    trainer.test()
    return 0
53
54
55
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status instead of
    # discarding it.
    raise SystemExit(main())
57
58