GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/unclip_image_interpolation.py
import inspect
from typing import List, Optional, Union

import PIL
import torch
from torch.nn import functional as F
from transformers import (
    CLIPImageProcessor,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    DiffusionPipeline,
    ImagePipelineOutput,
    UnCLIPScheduler,
    UNet2DConditionModel,
    UNet2DModel,
)
from diffusers.pipelines.unclip import UnCLIPTextProjModel
from diffusers.utils import is_accelerate_available, logging, randn_tensor


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
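
# A minimal usage sketch (not part of the original file). It assumes the unCLIP image-variation
# weights referenced in the `__call__` docstring below ("fusing/karlo-image-variations-diffusers")
# and that this module is loaded as the "unclip_image_interpolation" community pipeline via
# `custom_pipeline`; adjust the checkpoint, device, and dtype to your setup.
#
#   import torch
#   from PIL import Image
#   from diffusers import DiffusionPipeline
#
#   pipe = DiffusionPipeline.from_pretrained(
#       "fusing/karlo-image-variations-diffusers",
#       custom_pipeline="unclip_image_interpolation",
#       torch_dtype=torch.float16,
#   ).to("cuda")
#
#   images = [Image.open("start.png"), Image.open("end.png")]
#   output = pipe(image=images, steps=5)
#   for i, frame in enumerate(output.images):
#       frame.save(f"interpolation_{i}.png")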


def slerp(val, low, high):
    """
    Find the interpolation point between the 'low' and 'high' values for the given 'val'.
    See https://en.wikipedia.org/wiki/Slerp for more details on the topic.
    """
    low_norm = low / torch.norm(low)
    high_norm = high / torch.norm(high)
    omega = torch.acos((low_norm * high_norm))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
    return res
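# slerp returns `low` for val=0.0 and `high` for val=1.0; intermediate values of `val` blend the two
# inputs element-wise, which is how the image embeddings are interpolated in `__call__` below.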


class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
    """
    Pipeline to generate image interpolations between two input images using unCLIP.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    Args:
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `image_encoder`.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder. unCLIP image interpolation uses the vision portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        text_proj ([`UnCLIPTextProjModel`]):
            Utility class to prepare and combine the embeddings before they are passed to the decoder.
        decoder ([`UNet2DConditionModel`]):
            The decoder to invert the image embedding into an image.
        super_res_first ([`UNet2DModel`]):
            Super resolution unet. Used in all but the last step of the super resolution diffusion process.
        super_res_last ([`UNet2DModel`]):
            Super resolution unet. Used in the last step of the super resolution diffusion process.
        decoder_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
        super_res_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.

    """

    decoder: UNet2DConditionModel
    text_proj: UnCLIPTextProjModel
    text_encoder: CLIPTextModelWithProjection
    tokenizer: CLIPTokenizer
    feature_extractor: CLIPImageProcessor
    image_encoder: CLIPVisionModelWithProjection
    super_res_first: UNet2DModel
    super_res_last: UNet2DModel

    decoder_scheduler: UnCLIPScheduler
    super_res_scheduler: UnCLIPScheduler

    # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.__init__
    def __init__(
        self,
        decoder: UNet2DConditionModel,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_proj: UnCLIPTextProjModel,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection,
        super_res_first: UNet2DModel,
        super_res_last: UNet2DModel,
        decoder_scheduler: UnCLIPScheduler,
        super_res_scheduler: UnCLIPScheduler,
    ):
        super().__init__()

        self.register_modules(
            decoder=decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_proj=text_proj,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
            super_res_first=super_res_first,
            super_res_last=super_res_last,
            decoder_scheduler=decoder_scheduler,
            super_res_scheduler=super_res_scheduler,
        )

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_prompt
    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        text_mask = text_inputs.attention_mask.bool().to(device)
        text_encoder_output = self.text_encoder(text_input_ids.to(device))

        prompt_embeds = text_encoder_output.text_embeds
        text_encoder_hidden_states = text_encoder_output.last_hidden_state

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens = [""] * batch_size

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_encoder_hidden_states, text_mask

    # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_image
    def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
        dtype = next(self.image_encoder.parameters()).dtype

        if image_embeddings is None:
            if not isinstance(image, torch.Tensor):
                image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

            image = image.to(device=device, dtype=dtype)
            image_embeddings = self.image_encoder(image).image_embeds

        image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

        return image_embeddings

    # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
        when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.decoder,
            self.text_proj,
            self.text_encoder,
            self.super_res_first,
            self.super_res_last,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
            return self.device
        for module in self.decoder.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    def __call__(
        self,
        image: Optional[Union[List[PIL.Image.Image], torch.FloatTensor]] = None,
        steps: int = 5,
        decoder_num_inference_steps: int = 25,
        super_res_num_inference_steps: int = 7,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        image_embeddings: Optional[torch.Tensor] = None,
        decoder_latents: Optional[torch.FloatTensor] = None,
        super_res_latents: Optional[torch.FloatTensor] = None,
        decoder_guidance_scale: float = 8.0,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image (`List[PIL.Image.Image]` or `torch.FloatTensor`):
                The images to use for the image interpolation. Only accepts a list of two PIL Images. If you provide a
                tensor, it needs to comply with the configuration of
                [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
                `CLIPImageProcessor` while still having a size of two in the 0th dimension. Can be left as `None` only
                when `image_embeddings` are passed.
            steps (`int`, *optional*, defaults to 5):
                The number of interpolation images to generate.
            decoder_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            super_res_num_inference_steps (`int`, *optional*, defaults to 7):
                The number of denoising steps for super resolution. More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            image_embeddings (`torch.Tensor`, *optional*):
                Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
                can be passed for tasks like image interpolations. `image` can then be left as `None`.
            decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
                Pre-generated noisy latents to be used as inputs for the decoder.
            super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
                Pre-generated noisy latents to be used as inputs for the super resolution.
            decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        """

        batch_size = steps

        device = self._execution_device

        if isinstance(image, List):
            if len(image) != 2:
                raise AssertionError(
                    f"Expected 'image' List to be of size 2, but passed 'image' length is {len(image)}"
                )
            elif not (isinstance(image[0], PIL.Image.Image) and isinstance(image[1], PIL.Image.Image)):
                raise AssertionError(
                    f"Expected 'image' List to contain PIL.Image.Image, but passed 'image' contents are {type(image[0])} and {type(image[1])}"
                )
        elif isinstance(image, torch.FloatTensor):
            if image.shape[0] != 2:
                raise AssertionError(
                    f"Expected 'image' to be torch.FloatTensor of size 2 in the 0th dimension, but passed 'image' size is {image.shape[0]}"
                )
        elif isinstance(image_embeddings, torch.Tensor):
            if image_embeddings.shape[0] != 2:
                raise AssertionError(
                    f"Expected 'image_embeddings' to be torch.FloatTensor of size 2 in the 0th dimension, but passed 'image_embeddings' shape is {image_embeddings.shape[0]}"
                )
        else:
            raise AssertionError(
                f"Expected 'image' or 'image_embeddings' not to be None, with types List[PIL.Image] or torch.FloatTensor respectively. Received {type(image)} and {type(image_embeddings)} respectively"
            )

        original_image_embeddings = self._encode_image(
            image=image, device=device, num_images_per_prompt=1, image_embeddings=image_embeddings
        )

        image_embeddings = []

        for interp_step in torch.linspace(0, 1, steps):
            temp_image_embeddings = slerp(
                interp_step, original_image_embeddings[0], original_image_embeddings[1]
            ).unsqueeze(0)
            image_embeddings.append(temp_image_embeddings)

        image_embeddings = torch.cat(image_embeddings).to(device)
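        # image_embeddings now holds one embedding per interpolation frame, i.e. a tensor of shape (steps, embed_dim).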

        do_classifier_free_guidance = decoder_guidance_scale > 1.0

        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt=["" for i in range(steps)],
            device=device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
            image_embeddings=image_embeddings,
            prompt_embeds=prompt_embeds,
            text_encoder_hidden_states=text_encoder_hidden_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        if device.type == "mps":
            # HACK: MPS: There is a panic when padding bool tensors,
            # so cast to int tensor for the pad and back to bool afterwards
            text_mask = text_mask.type(torch.int)
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
            decoder_text_mask = decoder_text_mask.type(torch.bool)
        else:
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)

        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
        decoder_timesteps_tensor = self.decoder_scheduler.timesteps

        num_channels_latents = self.decoder.in_channels
        height = self.decoder.sample_size
        width = self.decoder.sample_size

        decoder_latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            text_encoder_hidden_states.dtype,
            device,
            generator,
            decoder_latents,
            self.decoder_scheduler,
        )

        for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents

            noise_pred = self.decoder(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_encoder_hidden_states,
                class_labels=additive_clip_time_embeddings,
                attention_mask=decoder_text_mask,
            ).sample

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
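                # The decoder predicts both noise and variance; guidance is applied to the noise
                # half only, and the predicted variance is re-attached for the scheduler step.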

            if i + 1 == decoder_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = decoder_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            decoder_latents = self.decoder_scheduler.step(
                noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        decoder_latents = decoder_latents.clamp(-1, 1)

        image_small = decoder_latents

        # done decoder

        # super res

        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
        super_res_timesteps_tensor = self.super_res_scheduler.timesteps

        channels = self.super_res_first.in_channels // 2
        height = self.super_res_first.sample_size
        width = self.super_res_first.sample_size

        super_res_latents = self.prepare_latents(
            (batch_size, channels, height, width),
            image_small.dtype,
            device,
            generator,
            super_res_latents,
            self.super_res_scheduler,
        )

        if device.type == "mps":
            # MPS does not support many interpolations
            image_upscaled = F.interpolate(image_small, size=[height, width])
        else:
            interpolate_antialias = {}
            if "antialias" in inspect.signature(F.interpolate).parameters:
                interpolate_antialias["antialias"] = True

            image_upscaled = F.interpolate(
                image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
            )
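        # The upscaled decoder output is concatenated channel-wise with the super resolution latents
        # below, so both super resolution unets are conditioned on the low resolution image.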

        for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
            # no classifier free guidance

            if i == super_res_timesteps_tensor.shape[0] - 1:
                unet = self.super_res_last
            else:
                unet = self.super_res_first

            latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)

            noise_pred = unet(
                sample=latent_model_input,
                timestep=t,
            ).sample

            if i + 1 == super_res_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = super_res_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            super_res_latents = self.super_res_scheduler.step(
                noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        image = super_res_latents
        # done super res

        # post processing

        image = image * 0.5 + 0.5
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)