GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/seed_resize_stable_diffusion.py
"""
Modified from the Stable Diffusion pipeline in the Hugging Face diffusers library: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
"""
import inspect
from typing import Callable, List, Optional, Union

import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class SeedResizeStableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
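
    Example:
        An illustrative sketch (not part of the original file); the model id, prompt, and sizes below are
        placeholders. Community pipelines such as this one are typically loaded through the `custom_pipeline`
        argument of `DiffusionPipeline.from_pretrained`:

        ```py
        import torch
        from diffusers import DiffusionPipeline

        pipe = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            custom_pipeline="seed_resize_stable_diffusion",
            torch_dtype=torch.float16,
        ).to("cuda")

        prompt = "a photo of an astronaut riding a horse"
        # re-seeding with the same seed at two widths should give similar compositions,
        # because both runs share the same 64x64 reference noise
        for width in (512, 768):
            generator = torch.Generator(device="cuda").manual_seed(42)
            image = pipe(prompt, height=512, width=width, generator=generator).images[0]
            image.save(f"astronaut_{width}.png")
        ```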
"""
48
49
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)
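
    # Illustrative usage of the two methods above (added for clarity, not part of the original file;
    # `pipe` stands for an already-instantiated pipeline, e.g. loaded as in the class docstring):
    #
    #     pipe.enable_attention_slicing()   # lower peak memory, slightly slower
    #     pipe.disable_attention_slicing()  # back to computing attention in a single step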

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        text_embeddings: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            text_embeddings (`torch.FloatTensor`, *optional*):
                Pre-computed text embeddings to use instead of encoding `prompt` with the text encoder. `prompt` is
                still tokenized, but the resulting token ids are only used for the truncation warning and to size the
                unconditional input.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

        if text_embeddings is None:
            text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""]
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
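            # (added comment) the unconditional embeddings come first in the batch built above; the
            # `noise_pred.chunk(2)` call in the denoising loop below relies on this [uncond, text] ordering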

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.in_channels, 64, 64)
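        # (added comment) the reference latents are always 64x64, i.e. the latent size of a 512x512 image, so for a
        # given seed the same base noise is drawn no matter which `height`/`width` the caller requests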
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not exist on mps
                latents_reference = torch.randn(
                    latents_shape_reference, generator=generator, device="cpu", dtype=latents_dtype
                ).to(self.device)
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents_reference = torch.randn(
                    latents_shape_reference, generator=generator, device=self.device, dtype=latents_dtype
                )
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            # minimal fix: the original code checked and moved an undefined `latents_reference` in this branch;
            # treat the user-supplied latents as the reference as well
            latents_reference = latents.to(self.device)
            latents = latents.to(self.device)

        # This is the key part of the pipeline where we
        # try to ensure that the generated images w/ the same seed
        # but different sizes actually result in similar images
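        # (added explanation) `latents_shape` is (B, C, height//8, width//8) and the reference is (B, C, 64, 64).
        # The code below center-crops whichever of the two is larger and pastes the reference noise into the middle
        # of the freshly drawn latents. Worked example for height=512, width=768: the latents are 64x96 and the
        # reference is 64x64, so dx=16, dy=0, w=64, h=64, tx=16, ty=0, and the reference fills columns 16..80.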
        dx = (latents_shape[3] - latents_shape_reference[3]) // 2
        dy = (latents_shape[2] - latents_shape_reference[2]) // 2
        w = latents_shape_reference[3] if dx >= 0 else latents_shape_reference[3] + 2 * dx
        h = latents_shape_reference[2] if dy >= 0 else latents_shape_reference[2] + 2 * dy
        tx = 0 if dx < 0 else dx
        ty = 0 if dy < 0 else dy
        dx = max(-dx, 0)
        dy = max(-dy, 0)
        latents[:, :, ty : ty + h, tx : tx + w] = latents_reference[:, :, dy : dy + h, dx : dx + w]

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

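        # (added comment) 0.18215 is the latent scaling factor used by the Stable Diffusion VAE; dividing by it
        # maps the latents back to the scale the VAE decoder expects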
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)