# GitHub Repository: shivamshrirao/diffusers
# Path: blob/main/examples/community/unclip_text_interpolation.py
import inspect
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import functional as F
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput

from diffusers import (
    DiffusionPipeline,
    ImagePipelineOutput,
    PriorTransformer,
    UnCLIPScheduler,
    UNet2DConditionModel,
    UNet2DModel,
)
from diffusers.pipelines.unclip import UnCLIPTextProjModel
from diffusers.utils import is_accelerate_available, logging, randn_tensor


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def slerp(val, low, high):
    """
    Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic.
    """
    low_norm = low / torch.norm(low)
    high_norm = high / torch.norm(high)
    omega = torch.acos((low_norm * high_norm))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
    return res

class UnCLIPTextInterpolationPipeline(DiffusionPipeline):

    """
    Pipeline for prompt-to-prompt interpolation on CLIP text embeddings, using the unCLIP / DALL-E 2 decoder to turn
    them into images.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        text_proj ([`UnCLIPTextProjModel`]):
            Utility class to prepare and combine the embeddings before they are passed to the decoder.
        decoder ([`UNet2DConditionModel`]):
            The decoder to invert the image embedding into an image.
        super_res_first ([`UNet2DModel`]):
            Super resolution unet. Used in all but the last step of the super resolution diffusion process.
        super_res_last ([`UNet2DModel`]):
            Super resolution unet. Used in the last step of the super resolution diffusion process.
        prior_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the prior denoising process. Just a modified DDPMScheduler.
        decoder_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
        super_res_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.

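    Examples:
        A minimal usage sketch. The checkpoint name and the prompts are illustrative; this assumes the file is
        loaded as a community pipeline (`custom_pipeline`) together with a compatible unCLIP checkpoint such as
        `kakaobrain/karlo-v1-alpha`.

        ```py
        >>> import torch
        >>> from diffusers import DiffusionPipeline

        >>> pipe = DiffusionPipeline.from_pretrained(
        ...     "kakaobrain/karlo-v1-alpha",
        ...     torch_dtype=torch.float16,
        ...     custom_pipeline="unclip_text_interpolation",
        ... )
        >>> pipe.to("cuda")

        >>> output = pipe(
        ...     start_prompt="A photograph of an adult lion",
        ...     end_prompt="A photograph of a lion cub",
        ...     steps=6,
        ... )
        >>> images = output.images  # list of `steps` images interpolating between the two prompts
        ```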
"""
68
69
prior: PriorTransformer
70
decoder: UNet2DConditionModel
71
text_proj: UnCLIPTextProjModel
72
text_encoder: CLIPTextModelWithProjection
73
tokenizer: CLIPTokenizer
74
super_res_first: UNet2DModel
75
super_res_last: UNet2DModel
76
77
prior_scheduler: UnCLIPScheduler
78
decoder_scheduler: UnCLIPScheduler
79
super_res_scheduler: UnCLIPScheduler
80
81
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__init__
82
def __init__(
83
self,
84
prior: PriorTransformer,
85
decoder: UNet2DConditionModel,
86
text_encoder: CLIPTextModelWithProjection,
87
tokenizer: CLIPTokenizer,
88
text_proj: UnCLIPTextProjModel,
89
super_res_first: UNet2DModel,
90
super_res_last: UNet2DModel,
91
prior_scheduler: UnCLIPScheduler,
92
decoder_scheduler: UnCLIPScheduler,
93
super_res_scheduler: UnCLIPScheduler,
94
):
95
super().__init__()
96
97
self.register_modules(
98
prior=prior,
99
decoder=decoder,
100
text_encoder=text_encoder,
101
tokenizer=tokenizer,
102
text_proj=text_proj,
103
super_res_first=super_res_first,
104
super_res_last=super_res_last,
105
prior_scheduler=prior_scheduler,
106
decoder_scheduler=decoder_scheduler,
107
super_res_scheduler=super_res_scheduler,
108
)
109
110
    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
    ):
        if text_model_output is None:
            batch_size = len(prompt) if isinstance(prompt, list) else 1
            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            text_mask = text_inputs.attention_mask.bool().to(device)

            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )
                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

            text_encoder_output = self.text_encoder(text_input_ids.to(device))

            prompt_embeds = text_encoder_output.text_embeds
            text_encoder_hidden_states = text_encoder_output.last_hidden_state

        else:
            batch_size = text_model_output[0].shape[0]
            prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
            text_mask = text_attention_mask

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens = [""] * batch_size

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_encoder_hidden_states, text_mask

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
        models = [
            self.decoder,
            self.text_proj,
            self.text_encoder,
            self.super_res_first,
            self.super_res_last,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
            return self.device
        for module in self.decoder.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    def __call__(
        self,
        start_prompt: str,
        end_prompt: str,
        steps: int = 5,
        prior_num_inference_steps: int = 25,
        decoder_num_inference_steps: int = 25,
        super_res_num_inference_steps: int = 7,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prior_guidance_scale: float = 4.0,
        decoder_guidance_scale: float = 8.0,
        enable_sequential_cpu_offload=True,
        gpu_id=0,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            start_prompt (`str`):
                The prompt to start the image generation interpolation from.
            end_prompt (`str`):
                The prompt to end the image generation interpolation at.
            steps (`int`, *optional*, defaults to 5):
                The number of steps over which to interpolate from `start_prompt` to `end_prompt`. The pipeline
                returns the same number of images as this value.
            prior_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            decoder_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            super_res_num_inference_steps (`int`, *optional*, defaults to 7):
                The number of denoising steps for super resolution. More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            enable_sequential_cpu_offload (`bool`, *optional*, defaults to `True`):
                If `True`, offloads all models to CPU using accelerate, significantly reducing memory usage. When
                called, the pipeline's models have their state dicts saved to CPU and then are moved to a
                `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method
                called.
            gpu_id (`int`, *optional*, defaults to `0`):
                The gpu_id to be passed to `enable_sequential_cpu_offload`. Only works when
                `enable_sequential_cpu_offload` is set to `True`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
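
        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is
            `True`, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.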
"""
323
324
if not isinstance(start_prompt, str) or not isinstance(end_prompt, str):
325
raise ValueError(
326
f"`start_prompt` and `end_prompt` should be of type `str` but got {type(start_prompt)} and"
327
f" {type(end_prompt)} instead"
328
)
329
330
if enable_sequential_cpu_offload:
331
self.enable_sequential_cpu_offload(gpu_id=gpu_id)
332
333
device = self._execution_device
334
335
# Turn the prompts into embeddings.
336
inputs = self.tokenizer(
337
[start_prompt, end_prompt],
338
padding="max_length",
339
truncation=True,
340
max_length=self.tokenizer.model_max_length,
341
return_tensors="pt",
342
)
343
inputs.to(device)
344
text_model_output = self.text_encoder(**inputs)
345
346
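        # The interpolated hidden states mix tokens from both prompts, so use the element-wise union (max) of the
        # two attention masks and repeat it once per interpolation step.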
        text_attention_mask = torch.max(inputs.attention_mask[0], inputs.attention_mask[1])
        text_attention_mask = torch.cat([text_attention_mask.unsqueeze(0)] * steps).to(device)

        # Interpolate from the start to the end prompt using slerp, collecting one embedding per interpolation step
        batch_text_embeds = []
        batch_last_hidden_state = []

        for interp_val in torch.linspace(0, 1, steps):
            text_embeds = slerp(interp_val, text_model_output.text_embeds[0], text_model_output.text_embeds[1])
            last_hidden_state = slerp(
                interp_val, text_model_output.last_hidden_state[0], text_model_output.last_hidden_state[1]
            )
            batch_text_embeds.append(text_embeds.unsqueeze(0))
            batch_last_hidden_state.append(last_hidden_state.unsqueeze(0))

        batch_text_embeds = torch.cat(batch_text_embeds)
        batch_last_hidden_state = torch.cat(batch_last_hidden_state)

        text_model_output = CLIPTextModelOutput(
            text_embeds=batch_text_embeds, last_hidden_state=batch_last_hidden_state
        )

        batch_size = text_model_output[0].shape[0]

        do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0

        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt=None,
            device=device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
            text_model_output=text_model_output,
            text_attention_mask=text_attention_mask,
        )

        # prior

        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
        prior_timesteps_tensor = self.prior_scheduler.timesteps

        embedding_dim = self.prior.config.embedding_dim

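        # The prior denoises directly in CLIP image-embedding space, so its latents are flat
        # (batch_size, embedding_dim) vectors rather than image-shaped tensors.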
        prior_latents = self.prepare_latents(
            (batch_size, embedding_dim),
            prompt_embeds.dtype,
            device,
            generator,
            None,
            self.prior_scheduler,
        )

        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents

            predicted_image_embedding = self.prior(
                latent_model_input,
                timestep=t,
                proj_embedding=prompt_embeds,
                encoder_hidden_states=text_encoder_hidden_states,
                attention_mask=text_mask,
            ).predicted_image_embedding

            if do_classifier_free_guidance:
                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
                predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
                    predicted_image_embedding_text - predicted_image_embedding_uncond
                )

            if i + 1 == prior_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = prior_timesteps_tensor[i + 1]

            prior_latents = self.prior_scheduler.step(
                predicted_image_embedding,
                timestep=t,
                sample=prior_latents,
                generator=generator,
                prev_timestep=prev_timestep,
            ).prev_sample

        prior_latents = self.prior.post_process_latents(prior_latents)

        image_embeddings = prior_latents

        # done prior

        # decoder

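        # text_proj combines the predicted image embeddings with the text conditioning: it returns the
        # cross-attention states for the decoder plus additive CLIP time embeddings passed below as `class_labels`.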
        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
            image_embeddings=image_embeddings,
            prompt_embeds=prompt_embeds,
            text_encoder_hidden_states=text_encoder_hidden_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        if device.type == "mps":
            # HACK: MPS: There is a panic when padding bool tensors,
            # so cast to int tensor for the pad and back to bool afterwards
            text_mask = text_mask.type(torch.int)
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
            decoder_text_mask = decoder_text_mask.type(torch.bool)
        else:
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)

        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
        decoder_timesteps_tensor = self.decoder_scheduler.timesteps

        num_channels_latents = self.decoder.in_channels
        height = self.decoder.sample_size
        width = self.decoder.sample_size

        decoder_latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            text_encoder_hidden_states.dtype,
            device,
            generator,
            None,
            self.decoder_scheduler,
        )

        for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents

            noise_pred = self.decoder(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_encoder_hidden_states,
                class_labels=additive_clip_time_embeddings,
                attention_mask=decoder_text_mask,
            ).sample

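            # The decoder predicts twice as many channels as the latents: the first half is the noise estimate, the
            # second half the predicted variance. Guidance is applied to the noise only; the variance from the
            # text-conditioned branch is re-attached afterwards for the scheduler step.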
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            if i + 1 == decoder_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = decoder_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            decoder_latents = self.decoder_scheduler.step(
                noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        decoder_latents = decoder_latents.clamp(-1, 1)

        image_small = decoder_latents

        # done decoder

        # super res

        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
        super_res_timesteps_tensor = self.super_res_scheduler.timesteps

        channels = self.super_res_first.in_channels // 2
        height = self.super_res_first.sample_size
        width = self.super_res_first.sample_size

        super_res_latents = self.prepare_latents(
            (batch_size, channels, height, width),
            image_small.dtype,
            device,
            generator,
            None,
            self.super_res_scheduler,
        )

        if device.type == "mps":
            # MPS does not support many interpolations
            image_upscaled = F.interpolate(image_small, size=[height, width])
        else:
            interpolate_antialias = {}
            if "antialias" in inspect.signature(F.interpolate).parameters:
                interpolate_antialias["antialias"] = True

            image_upscaled = F.interpolate(
                image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
            )

        for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
            # no classifier free guidance

            if i == super_res_timesteps_tensor.shape[0] - 1:
                unet = self.super_res_last
            else:
                unet = self.super_res_first

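            # Condition the super resolution unet on the upscaled decoder output by concatenating it with the
            # noisy latents along the channel dimension (hence `in_channels // 2` above).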
            latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)

            noise_pred = unet(
                sample=latent_model_input,
                timestep=t,
            ).sample

            if i + 1 == super_res_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = super_res_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            super_res_latents = self.super_res_scheduler.step(
                noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        image = super_res_latents
        # done super res

        # post processing

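        # map the images from the model's [-1, 1] range to [0, 1] before converting to numpy / PIL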
        image = image * 0.5 + 0.5
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)