Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/run_pixel.py
1192 views
1
import torch
2
import argparse
3
from assembler import get_config, assembler
4
from data import CelebADataModule
5
from pytorch_lightning import Trainer
6
import torchvision.transforms as transforms
7
8
# Load configs
9
from models.pixel_cnn import *
10
from experiment import *
11
12
# Command-line interface: one option selecting the YAML config file.
# The chosen path ends up on ``args.filename``.
parser = argparse.ArgumentParser(description="Generic runner for VAE models")
parser.add_argument(
    "--config",
    "-c",
    default="configs/vae.yaml",
    dest="filename",
    help="path to the config file",
    metavar="FILE",
)

args = parser.parse_args()
config_path = args.filename
config = get_config(config_path)
# Build the VAE described by the config through the project's assembler,
# in "training" mode.
vae = assembler(config, "training")

# Load data
# Preprocessing shared by training and validation: optional centre crop,
# resize to the configured size, then conversion to a tensor.
exp_params = config["exp_params"]
common_trans = []
if exp_params["crop_size"] > 0:
    common_trans.append(transforms.CenterCrop(exp_params["crop_size"]))
common_trans.append(transforms.Resize(exp_params["img_size"]))
common_trans.append(transforms.ToTensor())

# Training additionally applies a random horizontal flip as augmentation
# (same step order as before: flip, crop, resize, to-tensor).
trans = [transforms.RandomHorizontalFlip(), *common_trans]
transform = transforms.Compose(trans)
# Validation omits the random flip so its inputs are identical every epoch
# (assumes CelebADataModule applies val_transform to its validation split).
val_transform = transforms.Compose(common_trans)

dm = CelebADataModule(
    data_dir=exp_params["data_path"],
    target_type="attr",
    train_transform=transform,
    val_transform=val_transform,
    download=True,
    batch_size=exp_params["batch_size"],
)

# Initialise the VAE from a pretrained checkpoint. map_location="cpu" lets a
# checkpoint that was saved with GPU tensors load on a CPU-only machine;
# load_state_dict then copies the values into the model's own parameters.
vae.load_state_dict(torch.load(config["pixel_params"]["pretrained_path"], map_location="cpu"))

# PixelCNN hyper-parameters read from the config.
num_residual_blocks = config["pixel_params"]["num_residual_blocks"]
num_pixelcnn_layers = config["pixel_params"]["num_pixelcnn_layers"]
num_embeddings = config["vq_params"]["num_embeddings"]
hidden_dim = config["pixel_params"]["hidden_dim"]

# Run Training Loop
trainer_params = config["trainer_params"]
# NOTE(review): `gpus=` is the pre-2.0 PyTorch Lightning Trainer argument
# (later releases use accelerator=/devices=) — confirm the pinned version.
trainer = Trainer(
    gpus=trainer_params["gpus"],
    max_epochs=trainer_params["max_epochs"],
)

# Bare PixelCNN network, then the Lightning wrapper that also receives the
# pretrained VAE plus height, width and learning rate from pixel_params.
pixel_cnn_raw = PixelCNN(hidden_dim, num_residual_blocks, num_pixelcnn_layers, num_embeddings)
pixel_params = config["pixel_params"]
pixel_cnn = PixelCNNModule(
    pixel_cnn_raw,
    vae,
    pixel_params["height"],
    pixel_params["width"],
    pixel_params["LR"],
)

trainer.fit(pixel_cnn, datamodule=dm)
# NOTE(review): `save` is not part of the standard LightningModule API, so it
# is presumably defined by PixelCNNModule — confirm.
pixel_cnn.save(pixel_params["save_path"])