Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/vae/assembler.py
1192 views
1
import yaml
2
import importlib
3
from functools import partial
4
from models.guassian_vae import VAE
5
from models.two_stage_vae import Stage2VAE
6
from models.vq_vae import VQVAE
7
from models.pixel_cnn import PixelCNN
8
from experiment import VAEModule, VAE2stageModule, VQVAEModule, PixelCNNModule
9
10
11
def get_config(fpath):
    """Load a YAML configuration file and return its parsed contents.

    Args:
        fpath: Path to the YAML file.

    Returns:
        The object produced by ``yaml.safe_load`` for the file. Callers in this
        module index it as a nested dict (e.g. ``config["exp_params"]``).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML. The error is printed
            and then re-raised. Previously the code fell through to
            ``return config`` with ``config`` unbound, which masked the real
            parse error behind an ``UnboundLocalError``.
    """
    with open(fpath, "r", encoding="utf-8") as file:
        try:
            return yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
            raise
18
19
20
def is_vq_vae(config):
    """Return True if ``config["exp_params"]["template"]`` is ``"vq vae"``."""
    exp_params = config["exp_params"]
    return exp_params["template"] == "vq vae"
22
23
24
def is_two_stage(config):
    """Return True if ``config["exp_params"]["template"]`` is ``"2 stage vae"``."""
    exp_params = config["exp_params"]
    return exp_params["template"] == "2 stage vae"
26
27
28
def is_default_vae(config):
    """Return True if ``config["exp_params"]["template"]`` is ``"default vae"``."""
    exp_params = config["exp_params"]
    return exp_params["template"] == "default vae"
30
31
32
def is_config_valid(config):
    """Validate an experiment configuration, raising on the first failed check.

    Despite its name, this does not return a boolean. It returns ``None`` when
    every check passes and raises ``AssertionError`` otherwise. The checks are
    raised explicitly rather than written as ``assert`` statements, so they
    still run when Python is started with ``-O``.

    Checks performed:
      * ``exp_params.model_name`` and ``exp_params.template`` are strings.
      * The encoder and decoder agree on ``latent_dim``.
      * For the ``"2 stage vae"`` template:
          - ``stage1_params.model`` is a string;
          - ``./configs/<model>.yaml`` is itself valid (checked recursively);
          - its encoder ``latent_dim`` equals this encoder's ``input_dim``;
          - this decoder's ``output_dim`` equals this encoder's ``input_dim``.
      * For the ``"vq vae"`` template, the encoder ``latent_dim`` equals
        ``vq_params.embedding_dim``.

    Raises:
        AssertionError: If any check fails (with a descriptive message).
        KeyError: If a required section or key is missing.
    """

    def _require(condition, message):
        # Raise explicitly so the check survives ``python -O``. AssertionError
        # is kept so the raised type matches the original ``assert`` version.
        if not condition:
            raise AssertionError(message)

    exp_params = config["exp_params"]
    _require(isinstance(exp_params["model_name"], str),
             "exp_params.model_name must be a string")
    _require(isinstance(exp_params["template"], str),
             "exp_params.template must be a string")

    encoder_params = config["encoder_params"]
    decoder_params = config["decoder_params"]
    _require(
        encoder_params["latent_dim"] == decoder_params["latent_dim"],
        f"encoder latent_dim {encoder_params['latent_dim']!r} != "
        f"decoder latent_dim {decoder_params['latent_dim']!r}",
    )

    if is_two_stage(config):
        stage1_model = config["stage1_params"]["model"]
        _require(isinstance(stage1_model, str),
                 "stage1_params.model must be a string")
        config_stage_one = get_config(f"./configs/{stage1_model}.yaml")
        is_config_valid(config_stage_one)
        # Presumably the second stage consumes the first stage's latent codes,
        # hence the size checks below (TODO confirm against Stage2VAE).
        stage1_latent = config_stage_one["encoder_params"]["latent_dim"]
        _require(
            stage1_latent == encoder_params["input_dim"],
            f"stage-1 latent_dim {stage1_latent!r} != "
            f"encoder input_dim {encoder_params['input_dim']!r}",
        )
        _require(
            decoder_params["output_dim"] == encoder_params["input_dim"],
            f"decoder output_dim {decoder_params['output_dim']!r} != "
            f"encoder input_dim {encoder_params['input_dim']!r}",
        )
    elif is_vq_vae(config):
        embedding_dim = config["vq_params"]["embedding_dim"]
        _require(
            encoder_params["latent_dim"] == embedding_dim,
            f"encoder latent_dim {encoder_params['latent_dim']!r} != "
            f"vq_params.embedding_dim {embedding_dim!r}",
        )
47
48
49
def is_mode_training(mode):
    """Return True when *mode* is the string ``"training"``."""
    training_mode = "training"
    return mode == training_mode
51
52
53
def is_mode_inference(mode):
    """Return True when *mode* is the string ``"inference"``."""
    inference_mode = "inference"
    return mode == inference_mode
55
56
57
def get_first_stage_vae(config):
    """Build the first-stage model referenced by a two-stage config.

    ``config["stage1_params"]["model"]`` names a YAML file under ``./configs``.
    That config is assembled in ``"inference"`` mode, and ``load_model()`` is
    called on the result (presumably restoring saved weights) before it is
    returned.
    """
    stage1_name = config["stage1_params"]["model"]
    stage1_config = get_config(f"./configs/{stage1_name}.yaml")
    first_stage = assembler(stage1_config, "inference")
    first_stage.load_model()
    return first_stage
63
64
65
def compose_for_inference(models):
    """Chain model stages into a single inference-time model.

    Stages are folded left to right, so ``[a, b, c]`` becomes
    ``VAE2stageModule(VAE2stageModule(a, b), c)``. A single stage is returned
    unchanged.

    Args:
        models: Non-empty list (or tuple) of stage modules, earliest stage first.

    Returns:
        The lone element when there is one stage, otherwise the nested
        ``VAE2stageModule``.

    Raises:
        ValueError: If *models* is empty.
    """
    if not models:
        # Raise a real exception; raising a bare string is itself a TypeError.
        raise ValueError("empty model list")
    composed = models[0]
    for stage in models[1:]:
        composed = VAE2stageModule(composed, stage)
    return composed
75
76
77
def _assemble_stages(config, vae_name, loss, encoder, decoder):
    """Wire encoder/decoder/loss into the template's model; return stages, earliest first.

    Raises:
        ValueError: If ``exp_params.template`` is not one of the known templates.
    """
    exp_params = config["exp_params"]
    if is_default_vae(config):
        vae = VAE(vae_name, loss, encoder, decoder)
        return [VAEModule(vae, exp_params["LR"], config["encoder_params"]["latent_dim"])]
    if is_two_stage(config):
        # The second stage is built on top of an already-assembled first stage.
        first_stage = get_first_stage_vae(config)
        vae = Stage2VAE(vae_name, loss, encoder, decoder, first_stage)
        second_stage = VAEModule(vae, exp_params["LR"], config["encoder_params"]["latent_dim"])
        return [first_stage, second_stage]
    if is_vq_vae(config):
        vae = VQVAE(vae_name, loss, encoder, decoder, config["vq_params"])
        return [VQVAEModule(vae, config)]
    raise ValueError(
        f"unknown exp_params.template {exp_params['template']!r}; "
        "expected 'default vae', '2 stage vae' or 'vq vae'"
    )


def assembler(config, mode):
    """Build a VAE, wrapped in its experiment module, from a parsed config.

    The config is validated with ``is_config_valid``. ``Encoder``, ``Decoder``
    and ``loss`` are then taken from ``models.<exp_params.model_name>`` and
    wired into the model class selected by ``exp_params.template``.

    Args:
        config: Parsed experiment configuration (as returned by ``get_config``).
        mode: ``"training"`` or ``"inference"``.

    Returns:
        In ``"training"`` mode, the module built from *config*. For the
        two-stage template this is only the second stage. In ``"inference"``
        mode the result is the same, except that for the two-stage template
        the first and second stages are combined by ``compose_for_inference``.

    Raises:
        ValueError: If the template or *mode* is not recognised. Previously an
            unknown template surfaced as ``UnboundLocalError``, and an unknown
            mode was silently treated like training.
    """
    is_config_valid(config)
    if not (is_mode_training(mode) or is_mode_inference(mode)):
        raise ValueError(f"unknown mode {mode!r}; expected 'training' or 'inference'")

    # Load the model family's components from the ``models.<model_name>`` module.
    vae_name = config["exp_params"]["model_name"]
    components = importlib.import_module(f"models.{vae_name}")
    encoder = components.Encoder(**config["encoder_params"])
    decoder = components.Decoder(**config["decoder_params"])
    loss = partial(components.loss, config["loss_params"])

    stages = _assemble_stages(config, vae_name, loss, encoder, decoder)

    if is_mode_inference(mode) and is_two_stage(config):
        return compose_for_inference(stages)
    # Training mode, and inference for single-stage templates: return the
    # stage built from this config.
    return stages[-1]
111
112
113
if __name__ == "__main__":
    # Build each example config in training and then inference mode as a
    # quick sanity check; the assembled models are discarded.
    for name in ("hinge_vae", "two_stage_vae"):
        cfg = get_config(f"./configs/{name}.yaml")
        for run_mode in ("training", "inference"):
            assembler(cfg, run_mode)
120
121