# GitHub Repository: shivamshrirao/diffusers
# Path: blob/main/examples/community/bit_diffusion.py

from types import MethodType
from typing import Optional, Tuple, Union

import torch
from einops import rearrange, reduce

from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DConditionModel
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput


BITS = 8


# convert to bit representations and back, taken from https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py
def decimal_to_bits(x, bits=BITS):
    """expects image tensor ranging from 0 to 1, outputs bit tensor ranging from -1 to 1"""
    device = x.device

    x = (x * 255).int().clamp(0, 255)

    mask = 2 ** torch.arange(bits - 1, -1, -1, device=device)
    mask = rearrange(mask, "d -> d 1 1")
    x = rearrange(x, "b c h w -> b c 1 h w")

    bits = ((x & mask) != 0).float()
    bits = rearrange(bits, "b c d h w -> b (c d) h w")
    bits = bits * 2 - 1
    return bits


def bits_to_decimal(x, bits=BITS):
    """expects bits from -1 to 1, outputs image tensor from 0 to 1"""
    device = x.device

    x = (x > 0).int()
    mask = 2 ** torch.arange(bits - 1, -1, -1, device=device, dtype=torch.int32)

    mask = rearrange(mask, "d -> d 1 1")
    # use the `bits` argument rather than a hardcoded 8 so the helper honors its parameter
    x = rearrange(x, "b (c d) h w -> b c d h w", d=bits)
    dec = reduce(x * mask, "b c d h w -> b c h w", "sum")
    return (dec / 255).clamp(0.0, 1.0)
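

# A minimal round-trip sanity check (an illustrative addition, not part of the original
# file): quantizing an image to 8 bits, expanding it to a {-1, 1} bit tensor, and
# decoding again should reproduce the quantized image.
def _bit_roundtrip_sanity_check():
    x = torch.rand(2, 3, 8, 8)  # image tensor in [0, 1]
    b = decimal_to_bits(x)  # shape (2, 3 * BITS, 8, 8), values in {-1, 1}
    x_rec = bits_to_decimal(b)  # back to [0, 1], quantized to steps of 1/255
    assert torch.allclose(x_rec, (x * 255).int().clamp(0, 255) / 255)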


# modified scheduler step functions for clamping the predicted x_0 between -bit_scale and +bit_scale
def ddim_bit_scheduler_step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = True,
    generator=None,
    return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).
    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        eta (`float`): weight of noise for added noise in diffusion step.
        use_clipped_model_output (`bool`): if `True`, re-derive the noise prediction from the clipped
            predicted original sample, as done in GLIDE.
        generator: random number generator.
        return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
    Returns:
        [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Ideally, read the DDIM paper in detail to understand the following.

    # Notation (<variable name> -> <name in paper>)
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"

    # 1. get previous step value (=t-1)
    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[timestep]
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample from predicted noise, also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

    # 4. Clip "predicted x_0" to [-bit_scale, bit_scale]
    scale = self.bit_scale
    if self.config.clip_sample:
        pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the model_output is always re-derived from the clipped x_0 in GLIDE
        model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

    # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    if eta > 0:
        # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
        device = model_output.device if torch.is_tensor(model_output) else "cpu"
        noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
        variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise

        prev_sample = prev_sample + variance

    if not return_dict:
        return (prev_sample,)

    return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)


def ddpm_bit_scheduler_step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    prediction_type="epsilon",
    generator=None,
    return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).
    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        prediction_type (`str`, default `epsilon`):
            indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
        generator: random number generator.
        return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
    Returns:
        [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.
    """
    t = timestep

    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    # 1. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[t]
    alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # 2. compute predicted original sample from predicted noise, also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif prediction_type == "sample":
        pred_original_sample = model_output
    else:
        raise ValueError(f"Unsupported prediction_type {prediction_type}.")

    # 3. Clip "predicted x_0"
    scale = self.bit_scale
    if self.config.clip_sample:
        pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)

    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
    current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

    # 6. Add noise
    variance = 0
    if t > 0:
        noise = torch.randn(
            model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
        ).to(model_output.device)
        variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise

    pred_prev_sample = pred_prev_sample + variance

    if not return_dict:
        return (pred_prev_sample,)

    return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
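

# Illustrative sketch (not from the original file): the step functions above expect to
# be bound to a scheduler instance so that `self` refers to the scheduler. The
# BitDiffusion pipeline below does this in its constructor; the same pattern also
# works standalone, e.g. for a DDIMScheduler:
#
#   scheduler = DDIMScheduler(num_train_timesteps=1000)
#   scheduler.bit_scale = 1.0  # read by the patched step via self.bit_scale
#   scheduler.step = MethodType(ddim_bit_scheduler_step, scheduler)
#   scheduler.set_timesteps(50)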


class BitDiffusion(DiffusionPipeline):
    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, DDPMScheduler],
        bit_scale: Optional[float] = 1.0,
    ):
        super().__init__()

        self.bit_scale = bit_scale
        # Bind the clamping step function to the scheduler instance so that `self`
        # inside ddim_bit_scheduler_step / ddpm_bit_scheduler_step refers to the
        # scheduler; assigning the bare function would leave `self` unbound.
        scheduler.bit_scale = bit_scale  # read by the patched step via self.bit_scale
        scheduler.step = MethodType(
            ddim_bit_scheduler_step if isinstance(scheduler, DDIMScheduler) else ddpm_bit_scheduler_step,
            scheduler,
        )

        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        height: Optional[int] = 256,
        width: Optional[int] = 256,
        num_inference_steps: Optional[int] = 50,
        generator: Optional[torch.Generator] = None,
        batch_size: Optional[int] = 1,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        # decimal_to_bits expands the channel dimension by a factor of BITS, so the
        # initial noise is drawn with in_channels // BITS channels to match the unet input
        latents = torch.randn(
            (batch_size, self.unet.in_channels // BITS, height, width),
            generator=generator,
        )
        latents = decimal_to_bits(latents) * self.bit_scale
        latents = latents.to(self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # predict the noise residual
            noise_pred = self.unet(latents, t).sample

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        image = bits_to_decimal(latents)

        if output_type == "pil":
            # numpy_to_pil expects a (batch, height, width, channels) numpy array in [0, 1]
            image = image.cpu().permute(0, 2, 3, 1).numpy()
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
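

# Example usage (an illustrative sketch, not from the original file). It assumes a unet
# trained for bit diffusion, i.e. with in_channels == out_channels == 3 * BITS; the
# checkpoint path is a placeholder.
#
#   unet = UNet2DConditionModel.from_pretrained("path/to/bit-diffusion-unet")  # hypothetical checkpoint
#   scheduler = DDIMScheduler(num_train_timesteps=1000, clip_sample=True)
#   pipe = BitDiffusion(unet, scheduler, bit_scale=1.0).to("cuda")
#   images = pipe(height=64, width=64, num_inference_steps=50).images
#   images[0].save("bit_diffusion_sample.png")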