Path: blob/master/Bag-Of-Tricks-For-Image-Classification/main.py
3119 views
import torch
from pytorch_lightning import (
    Trainer,
    seed_everything,
)
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision.models import (
    resnet18,
    resnet50,
)

from model.model import (
    LitFood101,
    LitFood101KD,
)
from utils.args import get_program_level_args


def main():
    """Train and/or evaluate a ResNet-18 based Lightning module.

    Builds a parser from the program-level, model-specific and ``Trainer``
    arguments, seeds all RNGs via ``seed_everything``, and then either:

    * ``args.evaluate``: loads the ``state_dict`` stored in
      ``args.checkpoint`` into the module and runs the test loop once, or
    * otherwise: fits the module and runs the test loop afterwards.

    When ``args.use_knowledge_distillation`` is set, the ResNet-18 student is
    wrapped together with a ResNet-50 teacher in ``LitFood101KD``; otherwise
    it is wrapped in ``LitFood101``.

    Returns:
        ``0`` after an evaluation-only run; ``None`` after a training run.
    """
    parser = get_program_level_args()
    parser = LitFood101.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()
    seed_everything(args.seed)

    # Keep the checkpoint with the highest validation accuracy.
    checkpoint_callback = ModelCheckpoint(monitor="avg_val_acc", mode="max")
    trainer = Trainer.from_argparse_args(
        args,
        deterministic=True,
        # cuDNN benchmark mode may select non-deterministic kernels.
        benchmark=False,
        checkpoint_callback=checkpoint_callback,
        # Any AMP opt level other than "O0" selects 16-bit precision.
        precision=16 if args.amp_level != "O0" else 32,
    )

    # Create the model: an ImageNet-pretrained ResNet-18 backbone.
    backbone = resnet18(pretrained=True)
    if args.use_knowledge_distillation:
        # NOTE(review): the teacher is built without pretrained weights here;
        # presumably LitFood101KD loads trained teacher weights — confirm.
        teacher_model = resnet50(pretrained=False)
        model = LitFood101KD(backbone, teacher_model, args)
    else:
        model = LitFood101(backbone, args)

    if args.evaluate:
        # map_location="cpu" lets a checkpoint whose tensors were saved on a
        # GPU load on a CPU-only host; device placement is left to the Trainer.
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        trainer.test(model, test_dataloaders=model.test_dataloader())
        return 0

    trainer.fit(model)

    # With no model argument, test() reuses the model attached by fit(); in the
    # Lightning 1.x API this script targets, ckpt_path defaults to "best"
    # (the checkpoint kept by checkpoint_callback) — confirm for your version.
    trainer.test()


if __name__ == "__main__":
    main()