# GitHub Repository: shivamshrirao/diffusers
# Path: blob/main/examples/community/interpolate_stable_diffusion.py
import inspect
import time
from pathlib import Path
from typing import Callable, List, Optional, Union

import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Helper function to spherically interpolate between two arrays v0 and v1."""

    inputs_are_torch = False
    if not isinstance(v0, np.ndarray):
        inputs_are_torch = True
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()

    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        # the vectors are nearly parallel, so fall back to linear interpolation
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)

    return v2


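# Illustrative sketch (comments only, not executed on import): `slerp` is used further
# below to blend the initial noise of neighbouring frames. The shapes here are an
# assumption for a standard 512x512 Stable Diffusion latent, not taken from this file.
#
#   v0 = torch.randn(1, 4, 64, 64)
#   v1 = torch.randn(1, 4, 64, 64)
#   halfway = slerp(0.5, v0, v1)  # spherical midpoint between the two noise samples

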
class StableDiffusionWalkPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

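    # Illustrative usage (an assumption about the calling code, not part of this class):
    # with a pipeline instance `pipe`, memory can be traded for a little speed before
    # generation by computing attention in two steps, and restored afterwards:
    #
    #   pipe.enable_attention_slicing()   # "auto" slicing, attention computed in two steps
    #   pipe.disable_attention_slicing()  # back to single-step attention
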
    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        text_embeddings: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*, defaults to `None`):
                The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            text_embeddings (`torch.FloatTensor`, *optional*, defaults to `None`):
                Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of
                `prompt` to avoid re-computing the embeddings. If not provided, the embeddings will be generated from
                the supplied `prompt`.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if text_embeddings is None:
            if isinstance(prompt, str):
                batch_size = 1
            elif isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids

            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
                removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
                print(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )
                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
            text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
        else:
            batch_size = text_embeddings.shape[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = self.tokenizer.model_max_length
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

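    # Illustrative sketch (assumed calling code, comments only): because `__call__`
    # accepts `text_embeddings` directly, a caller can interpolate between two prompt
    # embeddings and render the intermediate point, e.g.:
    #
    #   embed_a = pipe.embed_text("a photo of a cat")
    #   embed_b = pipe.embed_text("a photo of a dog")
    #   blended = torch.lerp(embed_a, embed_b, 0.5)
    #   image = pipe(text_embeddings=blended, guidance_scale=7.5).images[0]
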
    def embed_text(self, text):
        """Takes in text and turns it into text embeddings."""
        text_input = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
        return embed

    def get_noise(self, seed, dtype=torch.float32, height=512, width=512):
        """Takes in a random seed and returns the corresponding noise vector."""
        return torch.randn(
            (1, self.unet.in_channels, height // 8, width // 8),
            generator=torch.Generator(device=self.device).manual_seed(seed),
            device=self.device,
            dtype=dtype,
        )

    def walk(
        self,
        prompts: List[str],
        seeds: List[int],
        num_interpolation_steps: Optional[int] = 6,
        output_dir: Optional[str] = "./dreams",
        name: Optional[str] = None,
        batch_size: Optional[int] = 1,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        guidance_scale: Optional[float] = 7.5,
        num_inference_steps: Optional[int] = 50,
        eta: Optional[float] = 0.0,
    ) -> List[str]:
        """
        Walks through a series of prompts and seeds, interpolating between them and saving the results to disk.

        Args:
            prompts (`List[str]`):
                List of prompts to generate images for.
            seeds (`List[int]`):
                List of seeds corresponding to provided prompts. Must be the same length as prompts.
            num_interpolation_steps (`int`, *optional*, defaults to 6):
                Number of interpolation steps to take between prompts.
            output_dir (`str`, *optional*, defaults to `./dreams`):
                Directory to save the generated images to.
            name (`str`, *optional*, defaults to `None`):
                Subdirectory of `output_dir` to save the generated images to. If `None`, the name will
                be the current time.
            batch_size (`int`, *optional*, defaults to 1):
                Number of images to generate at once.
            height (`int`, *optional*, defaults to 512):
                Height of the generated images.
            width (`int`, *optional*, defaults to 512):
                Width of the generated images.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.

        Returns:
            `List[str]`: List of paths to the generated images.
        """
        if not len(prompts) == len(seeds):
            raise ValueError(
                f"Number of prompts and seeds must be equal. Got {len(prompts)} prompts and {len(seeds)} seeds"
            )

        name = name or time.strftime("%Y%m%d-%H%M%S")
        save_path = Path(output_dir) / name
        save_path.mkdir(exist_ok=True, parents=True)

        frame_idx = 0
        frame_filepaths = []
        for prompt_a, prompt_b, seed_a, seed_b in zip(prompts, prompts[1:], seeds, seeds[1:]):
            # Embed Text
            embed_a = self.embed_text(prompt_a)
            embed_b = self.embed_text(prompt_b)

            # Get Noise
            noise_dtype = embed_a.dtype
            noise_a = self.get_noise(seed_a, noise_dtype, height, width)
            noise_b = self.get_noise(seed_b, noise_dtype, height, width)

            noise_batch, embeds_batch = None, None
            T = np.linspace(0.0, 1.0, num_interpolation_steps)
            for i, t in enumerate(T):
                noise = slerp(float(t), noise_a, noise_b)
                embed = torch.lerp(embed_a, embed_b, t)

                noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise], dim=0)
                embeds_batch = embed if embeds_batch is None else torch.cat([embeds_batch, embed], dim=0)

                batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == T.shape[0]
                if batch_is_ready:
                    outputs = self(
                        latents=noise_batch,
                        text_embeddings=embeds_batch,
                        height=height,
                        width=width,
                        guidance_scale=guidance_scale,
                        eta=eta,
                        num_inference_steps=num_inference_steps,
                    )
                    noise_batch, embeds_batch = None, None

                    for image in outputs["images"]:
                        frame_filepath = str(save_path / f"frame_{frame_idx:06d}.png")
                        image.save(frame_filepath)
                        frame_filepaths.append(frame_filepath)
                        frame_idx += 1
        return frame_filepaths

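
# Minimal usage sketch. Assumptions: a CUDA device is available, this file is loadable
# as the `interpolate_stable_diffusion` community pipeline, and the model id below is
# only an example. The guard keeps importing this module side-effect free.
if __name__ == "__main__":
    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        custom_pipeline="interpolate_stable_diffusion",
        torch_dtype=torch.float16,
    ).to("cuda")

    # Walk between two prompt/seed pairs, writing interpolated frames to ./dreams/<timestamp>/
    frame_paths = pipe.walk(
        prompts=["a photograph of a forest at dawn", "a photograph of a forest at night"],
        seeds=[42, 1337],
        num_interpolation_steps=8,
    )
    print(f"Wrote {len(frame_paths)} frames")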