Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/gan/assembler.py
1192 views
1
import yaml
2
import importlib
3
from functools import partial
4
from models.base_gan import GAN
5
6
7
def get_config(fpath):
    """Load an experiment configuration from a YAML file.

    The file is parsed with ``yaml.safe_load``. The rest of this module
    indexes the result as a nested mapping (e.g. ``config["exp_params"]``).

    Args:
        fpath: Path to the YAML configuration file.

    Returns:
        The document parsed from the file by ``yaml.safe_load``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid YAML. The parser error is
            printed and chained as ``__cause__``.
    """
    with open(fpath, "r", encoding="utf-8") as file:
        try:
            return yaml.safe_load(file)
        except yaml.YAMLError as exc:
            # Previously the error was printed and then an unbound ``config``
            # was returned, which surfaced as an unrelated UnboundLocalError.
            # Keep the diagnostic print, but fail with a meaningful exception.
            print(exc)
            raise ValueError(f"Could not parse YAML config file: {fpath}") from exc
14
15
16
def is_config_valid(config):
    """Check that the loss and the generator agree on the latent dimension.

    Args:
        config: Experiment configuration mapping. It must have
            ``"loss_params"`` and ``"generator_params"`` sections, each
            containing a ``"latent_dim"`` entry.

    Returns:
        ``True`` when the two ``latent_dim`` values are equal.

    Raises:
        KeyError: If a required section or ``"latent_dim"`` entry is missing.
        ValueError: If the two ``latent_dim`` values differ.
    """
    loss_dim = config["loss_params"]["latent_dim"]
    gen_dim = config["generator_params"]["latent_dim"]
    # Explicit check rather than ``assert``: assert statements are removed
    # when Python runs with -O, which would silently skip validation.
    if loss_dim != gen_dim:
        raise ValueError(
            "latent_dim mismatch: loss_params has "
            f"{loss_dim!r} but generator_params has {gen_dim!r}"
        )
    return True
18
19
20
def assembler(config):
    """Build a ``GAN`` from an experiment configuration.

    The module ``models.<model_name>`` is imported and must provide
    ``Discriminator``, ``Generator``, ``disc_loss`` and ``gen_loss``.

    When ``exp_params.refinement`` is set, the module
    ``sampling.<refinement>`` is imported and must provide a ``sampling``
    function. The ``GAN`` receives a one-argument callable that forwards
    to it together with the sampling params, generator and discriminator.

    Args:
        config: Experiment configuration mapping. Sections read here:
            ``exp_params`` (``model_name`` and an optional ``refinement``),
            ``discriminator_params``, ``generator_params`` and
            ``optimizer_params``. ``sampling_params`` is read only when a
            refinement sampler is configured, at the time it is called.

    Returns:
        The assembled ``GAN`` instance.

    Raises:
        KeyError: If a required config section or key is missing.
        ModuleNotFoundError: If the named model or sampling module cannot
            be imported.
        Any exception raised by ``is_config_valid`` for an inconsistent
        config.
    """
    # Validate cross-section consistency before building anything.
    is_config_valid(config)

    exp_params = config["exp_params"]

    # Model components come from the module models.<model_name>.
    gan_name = exp_params["model_name"]
    components = importlib.import_module(f"models.{gan_name}")
    discriminator = components.Discriminator(**config["discriminator_params"])
    generator = components.Generator(**config["generator_params"])
    # The loss functions receive the full config as their first argument.
    disc_loss = partial(components.disc_loss, config)
    gen_loss = partial(components.gen_loss, config)

    # Optional sampling component. A missing ``refinement`` key is treated
    # the same as an explicit ``refinement: null``.
    sampling_name = exp_params.get("refinement")
    if sampling_name is not None:
        sampler = importlib.import_module(f"sampling.{sampling_name}")

        def sampling(x):
            # ``sampling_params`` is looked up on every call, not at
            # assembly time, so it is only required if sampling is used.
            return sampler.sampling(config["sampling_params"], generator, discriminator, x)
    else:
        sampling = None

    return GAN(gan_name, generator, discriminator, gen_loss, disc_loss, sampling, config["optimizer_params"])
42
43
44
if __name__ == "__main__":
    # Build each listed model from its YAML config. The path is relative to
    # the current working directory.
    for name in ("dcgan",):
        cfg_path = f"./configs/{name}.yaml"
        gan = assembler(get_config(cfg_path))
50
51