Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/gan/run.py
1192 views
1
import torch
2
import argparse
3
from assembler import get_config, assembler
4
from download_celeba import celeba_dataloader
5
from pytorch_lightning import Trainer
6
import torchvision.transforms as transforms
7
8
# Load configs, assemble the GAN, load CelebA, train, and save the weights.


def build_parser():
    """Return the command-line parser for this training script.

    The only option is ``--config``/``-c`` (stored as ``filename``): the path
    to the config file that is passed to ``get_config``.
    """
    parser = argparse.ArgumentParser(description="Generic runner for GAN models")
    parser.add_argument(
        "--config",
        "-c",
        dest="filename",
        metavar="FILE",
        help="path to the config file",
        # NOTE(review): the default still points at the VAE config; confirm the
        # intended GAN config path before relying on the default.
        default="configs/vae.yaml",
    )
    return parser


def main(argv=None):
    """Train the configured GAN on CelebA and save its state dict.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    config = get_config(args.filename)
    gan = assembler(config)

    # Load data
    exp_params = config["exp_params"]
    dm = celeba_dataloader(
        exp_params["batch_size"],
        exp_params["img_size"],
        exp_params["crop_size"],
        exp_params["data_path"],
    )

    # Run Training Loop
    trainer_params = config["trainer_params"]
    # NOTE(review): the `gpus=` keyword was removed in PyTorch Lightning 2.0
    # (replaced by `accelerator=`/`devices=`); this assumes an older release.
    trainer = Trainer(gpus=trainer_params["gpus"], max_epochs=trainer_params["max_epochs"])
    trainer.fit(gan, datamodule=dm)
    # The checkpoint file name is derived from the assembled model's `name` attribute.
    torch.save(gan.state_dict(), f"{gan.name}_celeba.ckpt")


if __name__ == "__main__":
    main()
31
32