Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/gan/models/base_gan.py
1192 views
1
import torch
2
from torch import nn, Tensor
3
from typing import Any, Callable
4
from pytorch_lightning import LightningModule
5
6
7
class GAN(LightningModule):
    """Generic GAN training module with injected networks and losses.

    The generator, discriminator, loss functions and optional noise
    transform are supplied by the caller; this class wires them into
    Lightning's training loop with two Adam optimizers:

    * optimizer 0 updates the discriminator (``_disc_step``),
    * optimizer 1 updates the generator (``_gen_step``).

    Example::

        gan = GAN(name, generator, discriminator, gen_loss, disc_loss,
                  sampling=None,
                  config={"learning_rate": lr, "beta1": beta1})
        Trainer(...).fit(gan, datamodule)

    NOTE(review): ``training_step`` takes ``optimizer_idx``, i.e. it relies on
    Lightning 1.x automatic optimization with multiple optimizers; Lightning
    2.0 removed that mode -- confirm against the pinned version.
    """

    def __init__(
        self,
        name: str,
        generator: Callable,
        discriminator: Callable,
        gen_loss: Callable,
        disc_loss: Callable,
        sampling: Callable,
        config: dict,
    ) -> None:
        """
        Args:
            name: Model name; used to build the checkpoint path in
                ``load_model``.
            generator: Generator network. Must expose ``.parameters()`` (it is
                optimized in ``configure_optimizers``); called on the reshaped
                noise in ``forward``.
            discriminator: Discriminator network; must expose
                ``.parameters()``.
            gen_loss: Callable ``(discriminator, generator, num, real) ->
                Tensor`` returning the generator loss; ``num`` is
                ``current_epoch + 1``.
            disc_loss: Callable with the same signature returning the
                discriminator loss.
            sampling: Optional transform applied to the noise in ``forward``,
                or ``None`` to feed the noise to the generator unchanged.
            config: Dict providing ``"learning_rate"`` and ``"beta1"`` for the
                Adam optimizers.
        """
        super().__init__()
        # Records every constructor argument (networks and callables included)
        # in ``self.hparams``.
        self.save_hyperparameters()

        self.name = name
        self.generator = generator
        self.discriminator = discriminator
        self.learning_rate = config["learning_rate"]
        self.beta1 = config["beta1"]
        self.sampling = sampling
        # Bind both networks once so the step methods only pass (num, real).
        self.gen_loss = lambda num, real: gen_loss(discriminator, generator, num, real)
        self.disc_loss = lambda num, real: disc_loss(discriminator, generator, num, real)

    @staticmethod
    def _weights_init(m):
        """DCGAN-style initialisation for a single submodule ``m``.

        Conv* layers: weight ~ N(0, 0.02).
        BatchNorm* layers: weight ~ N(1, 0.02), bias = 0.
        Other layers are left untouched. Not called inside this class;
        presumably used via ``module.apply(GAN._weights_init)`` -- verify.
        """
        classname = m.__class__.__name__
        if "Conv" in classname:
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        elif "BatchNorm" in classname:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
            torch.nn.init.zeros_(m.bias)

    def configure_optimizers(self):
        """Return one Adam optimizer per network and no LR schedulers.

        The order matters: index 0 is the discriminator optimizer and index 1
        the generator optimizer, matching ``training_step``.
        """
        lr = self.learning_rate
        betas = (self.beta1, 0.999)
        opt_disc = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas)
        opt_gen = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas)
        return [opt_disc, opt_gen], []

    def forward(self, noise: Tensor) -> Tensor:
        """
        Generates samples given input noise.

        The noise is moved to the module's device, two trailing singleton
        dimensions are appended (``(*shape, 1, 1)``), the optional
        ``sampling`` transform is applied, and the result is fed to the
        generator.

        Example::

            noise = torch.randn(batch_size, latent_dim)
            gan = GAN.load_from_checkpoint(PATH)
            img = gan(noise)
        """
        noise = noise.to(self.device)
        noise = noise.view(*noise.shape, 1, 1)
        if self.sampling is not None:
            noise = self.sampling(noise)
        return self.generator(noise)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run the loss for the optimizer selected by ``optimizer_idx``.

        ``batch`` is unpacked as ``(real, _)``; the second element is ignored.
        Returns the discriminator loss for index 0, the generator loss for
        index 1, and ``None`` for any other index.
        """
        real, _ = batch

        result = None
        if optimizer_idx == 0:
            # Train discriminator
            result = self._disc_step(real)
        if optimizer_idx == 1:
            # Train generator
            result = self._gen_step(real)

        return result

    def _disc_step(self, real: Tensor) -> Tensor:
        """Compute and log the discriminator loss (1-based epoch passed as ``num``)."""
        disc_loss = self.disc_loss(self.trainer.current_epoch + 1, real)
        self.log("loss/disc", disc_loss, on_epoch=True)
        return disc_loss

    def _gen_step(self, real: Tensor) -> Tensor:
        """Compute and log the generator loss (1-based epoch passed as ``num``)."""
        gen_loss = self.gen_loss(self.trainer.current_epoch + 1, real)
        self.log("loss/gen", gen_loss, on_epoch=True)
        return gen_loss

    def _get_noise(self, n_samples: int, latent_dim: int) -> Tensor:
        """Return standard-normal noise of shape ``(n_samples, latent_dim)`` on the module's device."""
        return torch.randn(n_samples, latent_dim, device=self.device)

    def load_model(self):
        """Load weights from ``{name}_celeba.ckpt`` in the working directory.

        If the file does not exist, print a hint on how to train the model
        instead of raising.
        """
        try:
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            state_dict = torch.load(f"{self.name}_celeba.ckpt", map_location=self.device)
        except FileNotFoundError:
            # The original referenced ``self.model.name``, which does not exist
            # and turned this hint into an AttributeError.
            print(f"Please train the model using python run.py -c ./configs/{self.name}.yaml")
        else:
            self.load_state_dict(state_dict)
114
115