# Path: blob/master/labml_nn/diffusion/ddpm/experiment.py
"""1---2title: Denoising Diffusion Probabilistic Models (DDPM) training3summary: >4Training code for5Denoising Diffusion Probabilistic Model.6---78# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) training910[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/diffusion/ddpm/experiment.ipynb)1112This trains a DDPM based model on CelebA HQ dataset. You can find the download instruction in this13[discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).14Save the images inside [`data/celebA` folder](#dataset_path).1516The paper had used a exponential moving average of the model with a decay of $0.9999$. We have skipped this for17simplicity.18"""19from typing import List2021import torchvision22from PIL import Image2324import torch25import torch.utils.data26from labml import lab, tracker, experiment, monit27from labml.configs import BaseConfigs, option28from labml_nn.diffusion.ddpm import DenoiseDiffusion29from labml_nn.diffusion.ddpm.unet import UNet30from labml_nn.helpers.device import DeviceConfigs313233class Configs(BaseConfigs):34"""35## Configurations36"""37# Device to train the model on.38# [`DeviceConfigs`](../../device.html)39# picks up an available CUDA device or defaults to CPU.40device: torch.device = DeviceConfigs()4142# U-Net model for $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$43eps_model: UNet44# [DDPM algorithm](index.html)45diffusion: DenoiseDiffusion4647# Number of channels in the image. 
$3$ for RGB.48image_channels: int = 349# Image size50image_size: int = 3251# Number of channels in the initial feature map52n_channels: int = 6453# The list of channel numbers at each resolution.54# The number of channels is `channel_multipliers[i] * n_channels`55channel_multipliers: List[int] = [1, 2, 2, 4]56# The list of booleans that indicate whether to use attention at each resolution57is_attention: List[int] = [False, False, False, True]5859# Number of time steps $T$60n_steps: int = 1_00061# Batch size62batch_size: int = 6463# Number of samples to generate64n_samples: int = 1665# Learning rate66learning_rate: float = 2e-56768# Number of training epochs69epochs: int = 1_0007071# Dataset72dataset: torch.utils.data.Dataset73# Dataloader74data_loader: torch.utils.data.DataLoader7576# Adam optimizer77optimizer: torch.optim.Adam7879def init(self):80# Create $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ model81self.eps_model = UNet(82image_channels=self.image_channels,83n_channels=self.n_channels,84ch_mults=self.channel_multipliers,85is_attn=self.is_attention,86).to(self.device)8788# Create [DDPM class](index.html)89self.diffusion = DenoiseDiffusion(90eps_model=self.eps_model,91n_steps=self.n_steps,92device=self.device,93)9495# Create dataloader96self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)97# Create optimizer98self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)99100# Image logging101tracker.set_image("sample", True)102103def sample(self):104"""105### Sample images106"""107with torch.no_grad():108# $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$109x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],110device=self.device)111112# Remove noise for $T$ steps113for t_ in monit.iterate('Sample', self.n_steps):114# $t$115t = self.n_steps - t_ - 1116# Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$117x = 
self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))118119# Log samples120tracker.save('sample', x)121122def train(self):123"""124### Train125"""126127# Iterate through the dataset128for data in monit.iterate('Train', self.data_loader):129# Increment global step130tracker.add_global_step()131# Move data to device132data = data.to(self.device)133134# Make the gradients zero135self.optimizer.zero_grad()136# Calculate loss137loss = self.diffusion.loss(data)138# Compute gradients139loss.backward()140# Take an optimization step141self.optimizer.step()142# Track the loss143tracker.save('loss', loss)144145def run(self):146"""147### Training loop148"""149for _ in monit.loop(self.epochs):150# Train the model151self.train()152# Sample some images153self.sample()154# New line in the console155tracker.new_line()156157158class CelebADataset(torch.utils.data.Dataset):159"""160### CelebA HQ dataset161"""162163def __init__(self, image_size: int):164super().__init__()165166# CelebA images folder167folder = lab.get_data_path() / 'celebA'168# List of files169self._files = [p for p in folder.glob(f'**/*.jpg')]170171# Transformations to resize the image and convert to tensor172self._transform = torchvision.transforms.Compose([173torchvision.transforms.Resize(image_size),174torchvision.transforms.ToTensor(),175])176177def __len__(self):178"""179Size of the dataset180"""181return len(self._files)182183def __getitem__(self, index: int):184"""185Get an image186"""187img = Image.open(self._files[index])188return self._transform(img)189190191@option(Configs.dataset, 'CelebA')192def celeb_dataset(c: Configs):193"""194Create CelebA dataset195"""196return CelebADataset(c.image_size)197198199class MNISTDataset(torchvision.datasets.MNIST):200"""201### MNIST dataset202"""203204def __init__(self, image_size):205transform = 
torchvision.transforms.Compose([206torchvision.transforms.Resize(image_size),207torchvision.transforms.ToTensor(),208])209210super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)211212def __getitem__(self, item):213return super().__getitem__(item)[0]214215216@option(Configs.dataset, 'MNIST')217def mnist_dataset(c: Configs):218"""219Create MNIST dataset220"""221return MNISTDataset(c.image_size)222223224def main():225# Create experiment226experiment.create(name='diffuse', writers={'screen', 'labml'})227228# Create configurations229configs = Configs()230231# Set configurations. You can override the defaults by passing the values in the dictionary.232experiment.configs(configs, {233'dataset': 'CelebA', # 'MNIST'234'image_channels': 3, # 1,235'epochs': 100, # 5,236})237238# Initialize239configs.init()240241# Set models for saving and loading242experiment.add_pytorch_models({'eps_model': configs.eps_model})243244# Start and run the training loop245with experiment.start():246configs.run()247248249#250if __name__ == '__main__':251main()252253254