Path: blob/main/examples/community/bit_diffusion.py
from types import MethodType
from typing import Optional, Tuple, Union

import torch
from einops import rearrange, reduce

from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DConditionModel
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput


BITS = 8


# conversion to bit representations and back, taken from
# https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py
# (a small round-trip self-check for these two helpers is sketched at the bottom of this file)
def decimal_to_bits(x, bits=BITS):
    """expects image tensor ranging from 0 to 1, outputs bit tensor ranging from -1 to 1"""
    device = x.device

    x = (x * 255).int().clamp(0, 255)

    mask = 2 ** torch.arange(bits - 1, -1, -1, device=device)
    mask = rearrange(mask, "d -> d 1 1")
    x = rearrange(x, "b c h w -> b c 1 h w")

    bits = ((x & mask) != 0).float()
    bits = rearrange(bits, "b c d h w -> b (c d) h w")
    bits = bits * 2 - 1
    return bits


def bits_to_decimal(x, bits=BITS):
    """expects bits from -1 to 1, outputs image tensor from 0 to 1"""
    device = x.device

    x = (x > 0).int()
    mask = 2 ** torch.arange(bits - 1, -1, -1, device=device, dtype=torch.int32)

    mask = rearrange(mask, "d -> d 1 1")
    x = rearrange(x, "b (c d) h w -> b c d h w", d=bits)
    dec = reduce(x * mask, "b c d h w -> b c h w", "sum")
    return (dec / 255).clamp(0.0, 1.0)


# modified scheduler step functions for clamping the predicted x_0 between -bit_scale and +bit_scale
def ddim_bit_scheduler_step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = True,
    generator=None,
    return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).
    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        eta (`float`): weight of noise for added noise in diffusion step.
        use_clipped_model_output (`bool`):
            if `True`, re-derive `model_output` from the clipped predicted original sample (as done in GLIDE).
        generator: random number generator.
        return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
    Returns:
        [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Ideally, read the DDIM paper in detail for a full understanding.

    # Notation (<variable name> -> <name in paper>)
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"

    # 1. get previous step value (=t-1)
    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[timestep]
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

    # 4. Clip "predicted x_0"
    scale = self.bit_scale
    if self.config.clip_sample:
        pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the model_output is always re-derived from the clipped x_0 in Glide
        model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    if eta > 0:
        # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
        device = model_output.device if torch.is_tensor(model_output) else "cpu"
        noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
        variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise

        prev_sample = prev_sample + variance

    if not return_dict:
        return (prev_sample,)

    return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)


def ddpm_bit_scheduler_step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    prediction_type="epsilon",
    generator=None,
    return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).
    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        prediction_type (`str`, default `epsilon`):
            indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
        generator: random number generator.
        return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
    Returns:
        [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.
    """
    t = timestep

    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    # 1. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[t]
    alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif prediction_type == "sample":
        pred_original_sample = model_output
    else:
        raise ValueError(f"Unsupported prediction_type {prediction_type}.")

    # 3. Clip "predicted x_0"
    scale = self.bit_scale
    if self.config.clip_sample:
        pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)

    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
    current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

    # 6. Add noise
    variance = 0
    if t > 0:
        noise = torch.randn(
            model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
        ).to(model_output.device)
        variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise

    pred_prev_sample = pred_prev_sample + variance

    if not return_dict:
        return (pred_prev_sample,)

    return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)


class BitDiffusion(DiffusionPipeline):
    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, DDPMScheduler],
        bit_scale: Optional[float] = 1.0,
    ):
        super().__init__()
        self.bit_scale = bit_scale

        # the patched step functions above expect `self` to be the scheduler and read `bit_scale` from it,
        # so expose the scale on the scheduler and bind the chosen function to the scheduler instance
        scheduler.bit_scale = bit_scale
        scheduler.step = MethodType(
            ddim_bit_scheduler_step if isinstance(scheduler, DDIMScheduler) else ddpm_bit_scheduler_step, scheduler
        )

        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        height: Optional[int] = 256,
        width: Optional[int] = 256,
        num_inference_steps: Optional[int] = 50,
        generator: Optional[torch.Generator] = None,
        batch_size: Optional[int] = 1,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, height, width),
            generator=generator,
        )
        latents = decimal_to_bits(latents) * self.bit_scale
        latents = latents.to(self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # predict the noise residual
            noise_pred = self.unet(latents, t).sample

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        image = bits_to_decimal(latents)

        if output_type == "pil":
            # numpy_to_pil expects a NumPy array in (batch, height, width, channels) layout
            image = self.numpy_to_pil(image.cpu().permute(0, 2, 3, 1).numpy())

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
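

# How this community pipeline is typically consumed (a sketch, not verified here): pass
# `custom_pipeline="bit_diffusion"` to `DiffusionPipeline.from_pretrained`, pointing at a checkpoint whose
# UNet was trained on bit-encoded images. The repo id below is hypothetical.
#
#     pipe = DiffusionPipeline.from_pretrained(
#         "your-org/your-bit-diffusion-checkpoint",  # hypothetical repo id
#         custom_pipeline="bit_diffusion",
#     )
#     image = pipe(height=256, width=256, num_inference_steps=50).images[0]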
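

# A minimal self-check of the bit encoding helpers above (an illustrative sketch; nothing here is part of the
# pipeline itself). Encoding to bits and back is exact up to the 8-bit quantization step of 1/255.
if __name__ == "__main__":
    sample_images = torch.rand(2, 3, 16, 16)      # fake RGB images in [0, 1]
    sample_bits = decimal_to_bits(sample_images)  # -> (2, 3 * BITS, 16, 16) tensor of ±1 values
    reconstructed = bits_to_decimal(sample_bits)  # -> (2, 3, 16, 16) tensor back in [0, 1]

    max_error = (sample_images - reconstructed).abs().max().item()
    assert max_error <= 1 / 255 + 1e-6, f"unexpected round-trip error: {max_error}"
    print(f"bit round trip OK, max error: {max_error:.6f}")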