GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/composable_stable_diffusion.py
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, List, Optional, Union

import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import deprecate, is_accelerate_available, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class ComposableStableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
            [`EulerAncestralDiscreteScheduler`], or [`DPMSolverMultistepScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to True."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly, as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly, as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
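        # For the standard Stable Diffusion VAE this evaluates to 8 (four entries in
        # `block_out_channels`), i.e. a 512x512 image maps to a 64x64 latent grid.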
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
            # fix by only offloading self.safety_checker for now
            cpu_offload(self.safety_checker.vision_model, device)

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            device (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
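            # The result stacks the unconditional embedding(s) first, followed by the prompt
            # embeddings; e.g. a composed prompt ["a", "b"] yields rows [uncond, uncond, "a", "b"].
            # The extra unconditional rows are trimmed in `__call__` so that only one remains.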

        return text_embeddings

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
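        # 0.18215 is the scaling factor applied by the Stable Diffusion v1 VAE at encoding time
        # (`vae.config.scaling_factor`); dividing by it here undoes that scaling before decoding.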
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, prompt, height, width, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
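        # e.g. for a 512x512 output with vae_scale_factor=8 and 4 latent channels this is
        # (batch_size, 4, 64, 64)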
        if latents is None:
            if device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        weights: Optional[str] = "",
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation. Multiple prompts can be composed by joining them
                with `"|"` inside a single string.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            weights (`str`, *optional*, defaults to `""`):
                A `"|"`-separated string of guidance weights, one per composed prompt. If empty, every prompt is
                weighted with `guidance_scale`.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if "|" in prompt:
            prompt = [x.strip() for x in prompt.split("|")]
            print(f"composing {prompt}...")

            if not weights:
                # specify weights for prompts (excluding the unconditional score)
                print("using equal positive weights (conjunction) for all prompts...")
                weights = torch.tensor([guidance_scale] * len(prompt), device=self.device).reshape(-1, 1, 1, 1)
            else:
                # set prompt weight for each
                num_prompts = len(prompt) if isinstance(prompt, list) else 1
                weights = [float(w.strip()) for w in weights.split("|")]
                # fall back to the guidance scale if too few weights are given
                if len(weights) < num_prompts:
                    weights.append(guidance_scale)
                else:
                    weights = weights[:num_prompts]
                assert len(weights) == len(prompt), "weights specified are not equal to the number of prompts"
                weights = torch.tensor(weights, device=self.device).reshape(-1, 1, 1, 1)
        else:
            weights = guidance_scale
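        # Example (illustrative values): prompt="a forest | a castle" with weights="8 | 4" becomes
        # prompt=["a forest", "a castle"] and a weights tensor of shape (2, 1, 1, 1), so each
        # conditional noise prediction below receives its own guidance weight.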

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # composable diffusion
        if isinstance(prompt, list) and batch_size == 1:
            # remove extra unconditional embedding
            # N = one unconditional embed + conditional embeds
            text_embeddings = text_embeddings[len(prompt) - 1 :]
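            # _encode_prompt returned one unconditional embedding per composed prompt; keeping
            # [len(prompt) - 1:] leaves a single unconditional row followed by the len(prompt)
            # conditional rows.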

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
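                # note: only latent_model_input[:1] is passed to the UNet below, since each text
                # embedding is run through its own forward pass rather than one batched call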

                # predict the noise residual
                noise_pred = []
                for j in range(text_embeddings.shape[0]):
                    noise_pred.append(
                        self.unet(latent_model_input[:1], t, encoder_hidden_states=text_embeddings[j : j + 1]).sample
                    )
                noise_pred = torch.cat(noise_pred, dim=0)

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred[:1], noise_pred[1:]
                    noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(
                        dim=0, keepdim=True
                    )
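                    # i.e. eps = eps_uncond + sum_j w_j * (eps_j - eps_uncond): the weighted
                    # conditional scores are composed around a single unconditional score
                    # (a conjunction of the individual prompts).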

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
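

# A minimal usage sketch (not part of the original pipeline code). It assumes this file is
# loadable as a `custom_pipeline` community pipeline, that the base checkpoint
# "CompVis/stable-diffusion-v1-4" is available, and that a CUDA device is present; the prompts
# and weights below are illustrative only.
if __name__ == "__main__":
    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        custom_pipeline="composable_stable_diffusion",
        torch_dtype=torch.float16,
    ).to("cuda")

    # Prompts are composed with "|"; `weights` optionally assigns one guidance weight per prompt.
    result = pipe(
        prompt="a red sports car | a snowy mountain road",
        weights="7.5 | 7.5",
        num_inference_steps=50,
    )
    result.images[0].save("composed.png")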