GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/sd_text2img_k_diffusion.py
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import warnings
from typing import Callable, List, Optional, Union

import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser

from diffusers import DiffusionPipeline, LMSDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import is_accelerate_available, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

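# ModelWrapper adapts a diffusers UNet to the `apply_model(x, t, cond)` call signature
# expected by k-diffusion's CompVis wrappers, returning the raw `.sample` tensor instead
# of a diffusers output object.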
class ModelWrapper:
    def __init__(self, model, alphas_cumprod):
        self.model = model
        self.alphas_cumprod = alphas_cumprod

    def apply_model(self, *args, **kwargs):
        if len(args) == 3:
            encoder_hidden_states = args[-1]
            args = args[:2]
        if kwargs.get("cond", None) is not None:
            encoder_hidden_states = kwargs.pop("cond")
        return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample

class StableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker,
        feature_extractor,
    ):
        super().__init__()

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # get correct sigmas from LMS
        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

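        # Wrap the UNet so k-diffusion can treat it as a CompVis-style denoiser.
        # CompVisVDenoiser handles v-prediction checkpoints (e.g. the Stable Diffusion 2.x
        # 768px models); CompVisDenoiser handles the usual epsilon-prediction ones.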
        model = ModelWrapper(unet, scheduler.alphas_cumprod)
        if scheduler.prediction_type == "v_prediction":
            self.k_diffusion_model = CompVisVDenoiser(model)
        else:
            self.k_diffusion_model = CompVisDenoiser(model)

    def set_sampler(self, scheduler_type: str):
        warnings.warn("The `set_sampler` method is deprecated, please use `set_scheduler` instead.")
        return self.set_scheduler(scheduler_type)

    def set_scheduler(self, scheduler_type: str):
        library = importlib.import_module("k_diffusion")
        sampling = getattr(library, "sampling")
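        # `scheduler_type` is looked up by name in `k_diffusion.sampling`, e.g.
        # "sample_euler", "sample_euler_ancestral", "sample_heun", "sample_lms" or
        # "sample_dpm_2" (the exact set depends on the installed k-diffusion version).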
        self.sampler = getattr(sampling, scheduler_type)

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable attention slicing
        self.enable_attention_slicing(None)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list(int)`):
                prompt to be encoded
            device (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

        if not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
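        # 1 / 0.18215 undoes the latent scaling factor used by the Stable Diffusion VAE
        # before decoding latents back to pixel space.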
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def check_inputs(self, prompt, height, width, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // 8, width // 8)
        if latents is None:
            if device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # the initial noise is scaled by the scheduler's first sigma in `__call__`
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = True
        if guidance_scale <= 1.0:
            raise ValueError("`guidance_scale` has to be greater than 1, since this pipeline always applies classifier-free guidance.")

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)
        sigmas = self.scheduler.sigmas
        sigmas = sigmas.to(text_embeddings.dtype)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents = latents * sigmas[0]
        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)

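        # 6. Define the denoising function passed to the k-diffusion sampler: it runs the
        #    wrapped UNet on a doubled (unconditional + text-conditioned) batch and applies
        #    classifier-free guidance to the two noise predictions.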
        def model_fn(x, t):
            latent_model_input = torch.cat([x] * 2)

            noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings)

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            return noise_pred

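        # 7. Run the selected k-diffusion sampler over the full sigma schedule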
        latents = self.sampler(model_fn, latents, sigmas)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

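# Usage sketch: community pipelines like this one are typically loaded through
# `DiffusionPipeline.from_pretrained(..., custom_pipeline="sd_text2img_k_diffusion")`.
# The checkpoint id below is only an illustrative choice, and `set_scheduler` must be
# called before invoking the pipeline, since `__call__` relies on `self.sampler`.
if __name__ == "__main__":
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # example checkpoint; any compatible SD model id works
        custom_pipeline="sd_text2img_k_diffusion",
        torch_dtype=torch.float16,
    )
    pipe = pipe.to("cuda")

    # pick any sampler exposed by k_diffusion.sampling
    pipe.set_scheduler("sample_heun")

    image = pipe(
        "an astronaut riding a horse on the moon",
        num_inference_steps=25,
        guidance_scale=7.5,
    ).images[0]
    image.save("astronaut.png")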