# GitHub repository: shivamshrirao/diffusers
# Path: examples/community/lpw_stable_diffusion.py

import inspect
import re
from typing import Callable, List, Optional, Union

import numpy as np
import PIL
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import logging


try:
19
from diffusers.utils import PIL_INTERPOLATION
20
except ImportError:
21
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
22
PIL_INTERPOLATION = {
23
"linear": PIL.Image.Resampling.BILINEAR,
24
"bilinear": PIL.Image.Resampling.BILINEAR,
25
"bicubic": PIL.Image.Resampling.BICUBIC,
26
"lanczos": PIL.Image.Resampling.LANCZOS,
27
"nearest": PIL.Image.Resampling.NEAREST,
28
}
29
else:
30
PIL_INTERPOLATION = {
31
"linear": PIL.Image.LINEAR,
32
"bilinear": PIL.Image.BILINEAR,
33
"bicubic": PIL.Image.BICUBIC,
34
"lanczos": PIL.Image.LANCZOS,
35
"nearest": PIL.Image.NEAREST,
36
}
37
# ------------------------------------------------------------------------------
38
39
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
41
re_attention = re.compile(
42
r"""
43
\\\(|
44
\\\)|
45
\\\[|
46
\\]|
47
\\\\|
48
\\|
49
\(|
50
\[|
51
:([+-]?[.\d]+)\)|
52
\)|
53
]|
54
[^\\()\[\]:]+|
55
:
56
""",
57
re.X,
58
)
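
# Illustrative note: on the prompt "a (red:1.2) \(car\)", re_attention splits the text into the
# tokens "a ", "(", "red", ":1.2)", " ", "\(", "car", "\)", which parse_prompt_attention below
# turns into weighted text segments.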
59
60
61
def parse_prompt_attention(text):
62
"""
63
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
64
Accepted tokens are:
65
(abc) - increases attention to abc by a multiplier of 1.1
66
(abc:3.12) - increases attention to abc by a multiplier of 3.12
67
[abc] - decreases attention to abc by a multiplier of 1.1
68
\( - literal character '('
69
\[ - literal character '['
70
\) - literal character ')'
71
\] - literal character ']'
72
\\ - literal character '\'
73
anything else - just text
74
>>> parse_prompt_attention('normal text')
75
[['normal text', 1.0]]
76
>>> parse_prompt_attention('an (important) word')
77
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
78
>>> parse_prompt_attention('(unbalanced')
79
[['unbalanced', 1.1]]
80
>>> parse_prompt_attention('\(literal\]')
81
[['(literal]', 1.0]]
82
>>> parse_prompt_attention('(unnecessary)(parens)')
83
[['unnecessaryparens', 1.1]]
84
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
85
[['a ', 1.0],
86
['house', 1.5730000000000004],
87
[' ', 1.1],
88
['on', 1.0],
89
[' a ', 1.1],
90
['hill', 0.55],
91
[', sun, ', 1.1],
92
['sky', 1.4641000000000006],
93
['.', 1.1]]
94
"""
95
96
res = []
97
round_brackets = []
98
square_brackets = []
99
100
round_bracket_multiplier = 1.1
101
square_bracket_multiplier = 1 / 1.1
102
103
def multiply_range(start_position, multiplier):
104
for p in range(start_position, len(res)):
105
res[p][1] *= multiplier
106
107
for m in re_attention.finditer(text):
108
text = m.group(0)
109
weight = m.group(1)
110
111
if text.startswith("\\"):
112
res.append([text[1:], 1.0])
113
elif text == "(":
114
round_brackets.append(len(res))
115
elif text == "[":
116
square_brackets.append(len(res))
117
elif weight is not None and len(round_brackets) > 0:
118
multiply_range(round_brackets.pop(), float(weight))
119
elif text == ")" and len(round_brackets) > 0:
120
multiply_range(round_brackets.pop(), round_bracket_multiplier)
121
elif text == "]" and len(square_brackets) > 0:
122
multiply_range(square_brackets.pop(), square_bracket_multiplier)
123
else:
124
res.append([text, 1.0])
125
126
for pos in round_brackets:
127
multiply_range(pos, round_bracket_multiplier)
128
129
for pos in square_brackets:
130
multiply_range(pos, square_bracket_multiplier)
131
132
if len(res) == 0:
133
res = [["", 1.0]]
134
135
# merge runs of identical weights
136
i = 0
137
while i + 1 < len(res):
138
if res[i][1] == res[i + 1][1]:
139
res[i][0] += res[i + 1][0]
140
res.pop(i + 1)
141
else:
142
i += 1
143
144
return res
145
146
147
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
148
r"""
149
    Tokenize a list of prompts and return their tokens together with each token's weight.
150
151
No padding, starting or ending token is included.
152
"""
153
tokens = []
154
weights = []
155
truncated = False
156
for text in prompt:
157
texts_and_weights = parse_prompt_attention(text)
158
text_token = []
159
text_weight = []
160
for word, weight in texts_and_weights:
161
# tokenize and discard the starting and the ending token
162
token = pipe.tokenizer(word).input_ids[1:-1]
163
text_token += token
164
# copy the weight by length of token
165
text_weight += [weight] * len(token)
166
# stop if the text is too long (longer than truncation limit)
167
if len(text_token) > max_length:
168
truncated = True
169
break
170
# truncate
171
if len(text_token) > max_length:
172
truncated = True
173
text_token = text_token[:max_length]
174
text_weight = text_weight[:max_length]
175
tokens.append(text_token)
176
weights.append(text_weight)
177
if truncated:
178
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
179
return tokens, weights
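
# Illustrative sketch (not executed; `pipe` and the token ids are assumptions): a weighted prompt
# expands into one weight per sub-token, e.g.
#
#   tokens, weights = get_prompts_with_weights(pipe, ["a (red:1.2) car"], max_length=75)
#   # tokens  -> [[id_a, id_red, id_car]]   (ids depend on the CLIP tokenizer)
#   # weights -> [[1.0, 1.2, 1.0]]          (the 1.2 multiplier covers every sub-token of "red")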
180
181
182
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
183
r"""
184
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
185
"""
186
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
187
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
188
for i in range(len(tokens)):
189
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
190
if no_boseos_middle:
191
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
192
else:
193
w = []
194
if len(weights[i]) == 0:
195
w = [1.0] * weights_length
196
else:
197
for j in range(max_embeddings_multiples):
198
w.append(1.0) # weight for starting token in this chunk
199
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
200
w.append(1.0) # weight for ending token in this chunk
201
w += [1.0] * (weights_length - len(w))
202
weights[i] = w[:]
203
204
return tokens, weights
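
# Worked example of the padding arithmetic (values assumed for illustration): with chunk_length=77
# (CLIP) and max_length=227 (three chunks of 75 content tokens plus one BOS and one EOS), a
# 100-token prompt becomes [bos] + 100 tokens + 126 * [eos] = 227 entries, and its weights are
# padded with 1.0 so that every position carries a weight.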
205
206
207
def get_unweighted_text_embeddings(
208
pipe: StableDiffusionPipeline,
209
text_input: torch.Tensor,
210
chunk_length: int,
211
no_boseos_middle: Optional[bool] = True,
212
):
213
"""
214
    When the length of the token sequence exceeds the capacity of the text encoder, the input
    is split into chunks that are sent to the text encoder individually.
216
"""
217
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
218
if max_embeddings_multiples > 1:
219
text_embeddings = []
220
for i in range(max_embeddings_multiples):
221
# extract the i-th chunk
222
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
223
224
# cover the head and the tail by the starting and the ending tokens
225
text_input_chunk[:, 0] = text_input[0, 0]
226
text_input_chunk[:, -1] = text_input[0, -1]
227
text_embedding = pipe.text_encoder(text_input_chunk)[0]
228
229
if no_boseos_middle:
230
if i == 0:
231
# discard the ending token
232
text_embedding = text_embedding[:, :-1]
233
elif i == max_embeddings_multiples - 1:
234
# discard the starting token
235
text_embedding = text_embedding[:, 1:]
236
else:
237
# discard both starting and ending tokens
238
text_embedding = text_embedding[:, 1:-1]
239
240
text_embeddings.append(text_embedding)
241
text_embeddings = torch.concat(text_embeddings, axis=1)
242
else:
243
text_embeddings = pipe.text_encoder(text_input)[0]
244
return text_embeddings
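
# Shape sketch under assumed values: for text_input of shape (batch, 227) and chunk_length=77,
# three chunks of 77 tokens are encoded separately; with no_boseos_middle=True the first chunk
# drops its ending-token embedding, the last drops its starting-token embedding, and middle
# chunks drop both, so the concatenation is (batch, 227, hidden_dim) again.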
245
246
247
def get_weighted_text_embeddings(
248
pipe: StableDiffusionPipeline,
249
prompt: Union[str, List[str]],
250
uncond_prompt: Optional[Union[str, List[str]]] = None,
251
max_embeddings_multiples: Optional[int] = 3,
252
no_boseos_middle: Optional[bool] = False,
253
skip_parsing: Optional[bool] = False,
254
skip_weighting: Optional[bool] = False,
255
):
256
r"""
257
    Prompts can be assigned local weights using brackets. For example, the
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to those words get multiplied by a constant, 1.1.

    Also, to regularize the embedding, the weighted embedding is rescaled to preserve the original mean.
262
263
Args:
264
pipe (`StableDiffusionPipeline`):
265
Pipe to provide access to the tokenizer and the text encoder.
266
prompt (`str` or `List[str]`):
267
The prompt or prompts to guide the image generation.
268
uncond_prompt (`str` or `List[str]`):
269
            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
            is provided, the embeddings of prompt and uncond_prompt are concatenated.
271
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
272
The max multiple length of prompt embeddings compared to the max output length of text encoder.
273
no_boseos_middle (`bool`, *optional*, defaults to `False`):
274
            If the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep
            the starting and ending tokens in each of the chunks in the middle.
276
skip_parsing (`bool`, *optional*, defaults to `False`):
277
Skip the parsing of brackets.
278
skip_weighting (`bool`, *optional*, defaults to `False`):
279
Skip the weighting. When the parsing is skipped, it is forced True.
280
"""
281
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
282
if isinstance(prompt, str):
283
prompt = [prompt]
284
285
if not skip_parsing:
286
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
287
if uncond_prompt is not None:
288
if isinstance(uncond_prompt, str):
289
uncond_prompt = [uncond_prompt]
290
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
291
else:
292
prompt_tokens = [
293
token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
294
]
295
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
296
if uncond_prompt is not None:
297
if isinstance(uncond_prompt, str):
298
uncond_prompt = [uncond_prompt]
299
uncond_tokens = [
300
token[1:-1]
301
for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
302
]
303
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
304
305
# round up the longest length of tokens to a multiple of (model_max_length - 2)
306
max_length = max([len(token) for token in prompt_tokens])
307
if uncond_prompt is not None:
308
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
309
310
max_embeddings_multiples = min(
311
max_embeddings_multiples,
312
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
313
)
314
max_embeddings_multiples = max(1, max_embeddings_multiples)
315
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
316
317
# pad the length of tokens and weights
318
bos = pipe.tokenizer.bos_token_id
319
eos = pipe.tokenizer.eos_token_id
320
prompt_tokens, prompt_weights = pad_tokens_and_weights(
321
prompt_tokens,
322
prompt_weights,
323
max_length,
324
bos,
325
eos,
326
no_boseos_middle=no_boseos_middle,
327
chunk_length=pipe.tokenizer.model_max_length,
328
)
329
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
330
if uncond_prompt is not None:
331
uncond_tokens, uncond_weights = pad_tokens_and_weights(
332
uncond_tokens,
333
uncond_weights,
334
max_length,
335
bos,
336
eos,
337
no_boseos_middle=no_boseos_middle,
338
chunk_length=pipe.tokenizer.model_max_length,
339
)
340
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
341
342
# get the embeddings
343
text_embeddings = get_unweighted_text_embeddings(
344
pipe,
345
prompt_tokens,
346
pipe.tokenizer.model_max_length,
347
no_boseos_middle=no_boseos_middle,
348
)
349
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
350
if uncond_prompt is not None:
351
uncond_embeddings = get_unweighted_text_embeddings(
352
pipe,
353
uncond_tokens,
354
pipe.tokenizer.model_max_length,
355
no_boseos_middle=no_boseos_middle,
356
)
357
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
358
359
# assign weights to the prompts and normalize in the sense of mean
360
# TODO: should we normalize by chunk or in a whole (current implementation)?
361
if (not skip_parsing) and (not skip_weighting):
362
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
363
text_embeddings *= prompt_weights.unsqueeze(-1)
364
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
365
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
366
if uncond_prompt is not None:
367
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
368
uncond_embeddings *= uncond_weights.unsqueeze(-1)
369
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
370
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
371
372
if uncond_prompt is not None:
373
return text_embeddings, uncond_embeddings
374
return text_embeddings, None
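
# Usage sketch (assumes an already constructed StableDiffusionLongPromptWeightingPipeline `pipe`;
# the prompts are illustrative only):
#
#   cond, uncond = get_weighted_text_embeddings(
#       pipe,
#       prompt="a (very beautiful:1.3) landscape, highly detailed",
#       uncond_prompt="lowres, blurry",
#       max_embeddings_multiples=3,
#   )
#   # cond / uncond: (1, up to 227, hidden_dim), built from 77-token CLIP chunks.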
375
376
377
def preprocess_image(image):
378
w, h = image.size
379
    w, h = map(lambda x: x - x % 32, (w, h))  # round down to the nearest integer multiple of 32
380
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
381
image = np.array(image).astype(np.float32) / 255.0
382
image = image[None].transpose(0, 3, 1, 2)
383
image = torch.from_numpy(image)
384
return 2.0 * image - 1.0
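
# Example (assumed sizes): a 512x512 RGB PIL image becomes a float tensor of shape
# (1, 3, 512, 512) with values rescaled from [0, 255] to [-1, 1].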
385
386
387
def preprocess_mask(mask, scale_factor=8):
388
mask = mask.convert("L")
389
w, h = mask.size
390
    w, h = map(lambda x: x - x % 32, (w, h))  # round down to the nearest integer multiple of 32
391
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
392
mask = np.array(mask).astype(np.float32) / 255.0
393
mask = np.tile(mask, (4, 1, 1))
394
    mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension (the transpose itself is a no-op)
395
mask = 1 - mask # repaint white, keep black
396
mask = torch.from_numpy(mask)
397
return mask
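
# Example (assumed sizes): a 512x512 PIL mask with scale_factor=8 becomes a tensor of shape
# (1, 4, 64, 64). Because of the `1 - mask` inversion, white input pixels end up as 0
# (repainted region) and black pixels as 1 (preserved region).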
398
399
400
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
401
r"""
402
    Pipeline for text-to-image generation using Stable Diffusion, without a token length limit and with support
    for parsing weights in the prompt.
404
405
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
406
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
407
408
Args:
409
vae ([`AutoencoderKL`]):
410
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
411
text_encoder ([`CLIPTextModel`]):
412
Frozen text-encoder. Stable Diffusion uses the text portion of
413
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
414
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
415
tokenizer (`CLIPTokenizer`):
416
Tokenizer of class
417
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
418
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
419
scheduler ([`SchedulerMixin`]):
420
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
421
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
422
safety_checker ([`StableDiffusionSafetyChecker`]):
423
Classification module that estimates whether generated images could be considered offensive or harmful.
424
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
425
feature_extractor ([`CLIPImageProcessor`]):
426
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
427
"""
428
429
if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
430
431
def __init__(
432
self,
433
vae: AutoencoderKL,
434
text_encoder: CLIPTextModel,
435
tokenizer: CLIPTokenizer,
436
unet: UNet2DConditionModel,
437
scheduler: SchedulerMixin,
438
safety_checker: StableDiffusionSafetyChecker,
439
feature_extractor: CLIPImageProcessor,
440
requires_safety_checker: bool = True,
441
):
442
super().__init__(
443
vae=vae,
444
text_encoder=text_encoder,
445
tokenizer=tokenizer,
446
unet=unet,
447
scheduler=scheduler,
448
safety_checker=safety_checker,
449
feature_extractor=feature_extractor,
450
requires_safety_checker=requires_safety_checker,
451
)
452
self.__init__additional__()
453
454
else:
455
456
def __init__(
457
self,
458
vae: AutoencoderKL,
459
text_encoder: CLIPTextModel,
460
tokenizer: CLIPTokenizer,
461
unet: UNet2DConditionModel,
462
scheduler: SchedulerMixin,
463
safety_checker: StableDiffusionSafetyChecker,
464
feature_extractor: CLIPImageProcessor,
465
):
466
super().__init__(
467
vae=vae,
468
text_encoder=text_encoder,
469
tokenizer=tokenizer,
470
unet=unet,
471
scheduler=scheduler,
472
safety_checker=safety_checker,
473
feature_extractor=feature_extractor,
474
)
475
self.__init__additional__()
476
477
def __init__additional__(self):
478
if not hasattr(self, "vae_scale_factor"):
479
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
480
481
@property
482
def _execution_device(self):
483
r"""
484
Returns the device on which the pipeline's models will be executed. After calling
485
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
486
hooks.
487
"""
488
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
489
return self.device
490
for module in self.unet.modules():
491
if (
492
hasattr(module, "_hf_hook")
493
and hasattr(module._hf_hook, "execution_device")
494
and module._hf_hook.execution_device is not None
495
):
496
return torch.device(module._hf_hook.execution_device)
497
return self.device
498
499
def _encode_prompt(
500
self,
501
prompt,
502
device,
503
num_images_per_prompt,
504
do_classifier_free_guidance,
505
negative_prompt,
506
max_embeddings_multiples,
507
):
508
r"""
509
Encodes the prompt into text encoder hidden states.
510
511
Args:
512
            prompt (`str` or `List[str]`):
513
prompt to be encoded
514
device: (`torch.device`):
515
torch device
516
num_images_per_prompt (`int`):
517
number of images that should be generated per prompt
518
do_classifier_free_guidance (`bool`):
519
whether to use classifier free guidance or not
520
negative_prompt (`str` or `List[str]`):
521
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
522
if `guidance_scale` is less than `1`).
523
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
524
The max multiple length of prompt embeddings compared to the max output length of text encoder.
525
"""
526
batch_size = len(prompt) if isinstance(prompt, list) else 1
527
528
if negative_prompt is None:
529
negative_prompt = [""] * batch_size
530
elif isinstance(negative_prompt, str):
531
negative_prompt = [negative_prompt] * batch_size
532
if batch_size != len(negative_prompt):
533
raise ValueError(
534
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
535
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
536
" the batch size of `prompt`."
537
)
538
539
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
540
pipe=self,
541
prompt=prompt,
542
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
543
max_embeddings_multiples=max_embeddings_multiples,
544
)
545
bs_embed, seq_len, _ = text_embeddings.shape
546
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
547
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
548
549
if do_classifier_free_guidance:
550
bs_embed, seq_len, _ = uncond_embeddings.shape
551
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
552
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
553
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
554
555
return text_embeddings
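
    # Shape note (illustrative): with classifier-free guidance the returned tensor stacks the
    # unconditional and conditional embeddings, giving shape
    # (2 * batch_size * num_images_per_prompt, seq_len, hidden_dim).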
556
557
def check_inputs(self, prompt, height, width, strength, callback_steps):
558
if not isinstance(prompt, str) and not isinstance(prompt, list):
559
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
560
561
if strength < 0 or strength > 1:
562
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
563
564
if height % 8 != 0 or width % 8 != 0:
565
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
566
567
if (callback_steps is None) or (
568
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
569
):
570
raise ValueError(
571
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
572
f" {type(callback_steps)}."
573
)
574
575
def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
576
if is_text2img:
577
return self.scheduler.timesteps.to(device), num_inference_steps
578
else:
579
# get the original timestep using init_timestep
580
offset = self.scheduler.config.get("steps_offset", 0)
581
init_timestep = int(num_inference_steps * strength) + offset
582
init_timestep = min(init_timestep, num_inference_steps)
583
584
t_start = max(num_inference_steps - init_timestep + offset, 0)
585
timesteps = self.scheduler.timesteps[t_start:].to(device)
586
return timesteps, num_inference_steps - t_start
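
    # Worked example (assumed values): for img2img with num_inference_steps=50, strength=0.8 and
    # steps_offset=0, init_timestep = 40 and t_start = 10, so only the last 40 of the 50 scheduler
    # timesteps are used and 40 denoising steps actually run.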
587
588
def run_safety_checker(self, image, device, dtype):
589
if self.safety_checker is not None:
590
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
591
image, has_nsfw_concept = self.safety_checker(
592
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
593
)
594
else:
595
has_nsfw_concept = None
596
return image, has_nsfw_concept
597
598
def decode_latents(self, latents):
599
latents = 1 / 0.18215 * latents
600
image = self.vae.decode(latents).sample
601
image = (image / 2 + 0.5).clamp(0, 1)
602
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
603
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
604
return image
605
606
def prepare_extra_step_kwargs(self, generator, eta):
607
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
608
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
609
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
610
# and should be between [0, 1]
611
612
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
613
extra_step_kwargs = {}
614
if accepts_eta:
615
extra_step_kwargs["eta"] = eta
616
617
# check if the scheduler accepts generator
618
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
619
if accepts_generator:
620
extra_step_kwargs["generator"] = generator
621
return extra_step_kwargs
622
623
def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
624
if image is None:
625
shape = (
626
batch_size,
627
self.unet.in_channels,
628
height // self.vae_scale_factor,
629
width // self.vae_scale_factor,
630
)
631
632
if latents is None:
633
if device.type == "mps":
634
# randn does not work reproducibly on mps
635
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
636
else:
637
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
638
else:
639
if latents.shape != shape:
640
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
641
latents = latents.to(device)
642
643
# scale the initial noise by the standard deviation required by the scheduler
644
latents = latents * self.scheduler.init_noise_sigma
645
return latents, None, None
646
else:
647
init_latent_dist = self.vae.encode(image).latent_dist
648
init_latents = init_latent_dist.sample(generator=generator)
649
init_latents = 0.18215 * init_latents
650
init_latents = torch.cat([init_latents] * batch_size, dim=0)
651
init_latents_orig = init_latents
652
shape = init_latents.shape
653
654
# add noise to latents using the timesteps
655
if device.type == "mps":
656
noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
657
else:
658
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
659
latents = self.scheduler.add_noise(init_latents, noise, timestep)
660
return latents, init_latents_orig, noise
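
    # Shape sketch (assumed 512x512 generation, vae_scale_factor=8): text2img draws random latents
    # of shape (batch, 4, 64, 64); img2img/inpaint instead encode `image` with the VAE, scale the
    # result by 0.18215, and add noise up to the starting timestep.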
661
662
@torch.no_grad()
663
def __call__(
664
self,
665
prompt: Union[str, List[str]],
666
negative_prompt: Optional[Union[str, List[str]]] = None,
667
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
668
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
669
height: int = 512,
670
width: int = 512,
671
num_inference_steps: int = 50,
672
guidance_scale: float = 7.5,
673
strength: float = 0.8,
674
num_images_per_prompt: Optional[int] = 1,
675
eta: float = 0.0,
676
generator: Optional[torch.Generator] = None,
677
latents: Optional[torch.FloatTensor] = None,
678
max_embeddings_multiples: Optional[int] = 3,
679
output_type: Optional[str] = "pil",
680
return_dict: bool = True,
681
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
682
is_cancelled_callback: Optional[Callable[[], bool]] = None,
683
callback_steps: int = 1,
684
):
685
r"""
686
Function invoked when calling the pipeline for generation.
687
688
Args:
689
prompt (`str` or `List[str]`):
690
The prompt or prompts to guide the image generation.
691
negative_prompt (`str` or `List[str]`, *optional*):
692
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
693
if `guidance_scale` is less than `1`).
694
image (`torch.FloatTensor` or `PIL.Image.Image`):
695
`Image`, or tensor representing an image batch, that will be used as the starting point for the
696
process.
697
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
698
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
699
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
700
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
701
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
702
height (`int`, *optional*, defaults to 512):
703
The height in pixels of the generated image.
704
width (`int`, *optional*, defaults to 512):
705
The width in pixels of the generated image.
706
num_inference_steps (`int`, *optional*, defaults to 50):
707
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
708
expense of slower inference.
709
guidance_scale (`float`, *optional*, defaults to 7.5):
710
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
711
`guidance_scale` is defined as `w` of equation 2. of [Imagen
712
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
713
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
714
usually at the expense of lower image quality.
715
strength (`float`, *optional*, defaults to 0.8):
716
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
717
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
718
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
719
noise will be maximum and the denoising process will run for the full number of iterations specified in
720
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
721
num_images_per_prompt (`int`, *optional*, defaults to 1):
722
The number of images to generate per prompt.
723
eta (`float`, *optional*, defaults to 0.0):
724
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
725
[`schedulers.DDIMScheduler`], will be ignored for others.
726
generator (`torch.Generator`, *optional*):
727
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
728
deterministic.
729
latents (`torch.FloatTensor`, *optional*):
730
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
731
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
732
                tensor will be generated by sampling using the supplied random `generator`.
733
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
734
The max multiple length of prompt embeddings compared to the max output length of text encoder.
735
output_type (`str`, *optional*, defaults to `"pil"`):
736
                The output format of the generated image. Choose between
737
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
738
return_dict (`bool`, *optional*, defaults to `True`):
739
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
740
plain tuple.
741
callback (`Callable`, *optional*):
742
A function that will be called every `callback_steps` steps during inference. The function will be
743
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
744
is_cancelled_callback (`Callable`, *optional*):
745
A function that will be called every `callback_steps` steps during inference. If the function returns
746
`True`, the inference will be cancelled.
747
callback_steps (`int`, *optional*, defaults to 1):
748
The frequency at which the `callback` function will be called. If not specified, the callback will be
749
called at every step.
750
751
Returns:
752
`None` if cancelled by `is_cancelled_callback`,
753
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
754
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
755
When returning a tuple, the first element is a list with the generated images, and the second element is a
756
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
757
(nsfw) content, according to the `safety_checker`.
758
"""
759
# 0. Default height and width to unet
760
height = height or self.unet.config.sample_size * self.vae_scale_factor
761
width = width or self.unet.config.sample_size * self.vae_scale_factor
762
763
# 1. Check inputs. Raise error if not correct
764
self.check_inputs(prompt, height, width, strength, callback_steps)
765
766
# 2. Define call parameters
767
batch_size = 1 if isinstance(prompt, str) else len(prompt)
768
device = self._execution_device
769
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
770
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
771
# corresponds to doing no classifier free guidance.
772
do_classifier_free_guidance = guidance_scale > 1.0
773
774
# 3. Encode input prompt
775
text_embeddings = self._encode_prompt(
776
prompt,
777
device,
778
num_images_per_prompt,
779
do_classifier_free_guidance,
780
negative_prompt,
781
max_embeddings_multiples,
782
)
783
dtype = text_embeddings.dtype
784
785
# 4. Preprocess image and mask
786
if isinstance(image, PIL.Image.Image):
787
image = preprocess_image(image)
788
if image is not None:
789
image = image.to(device=self.device, dtype=dtype)
790
if isinstance(mask_image, PIL.Image.Image):
791
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
792
if mask_image is not None:
793
mask = mask_image.to(device=self.device, dtype=dtype)
794
mask = torch.cat([mask] * batch_size * num_images_per_prompt)
795
else:
796
mask = None
797
798
# 5. set timesteps
799
self.scheduler.set_timesteps(num_inference_steps, device=device)
800
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
801
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
802
803
# 6. Prepare latent variables
804
latents, init_latents_orig, noise = self.prepare_latents(
805
image,
806
latent_timestep,
807
batch_size * num_images_per_prompt,
808
height,
809
width,
810
dtype,
811
device,
812
generator,
813
latents,
814
)
815
816
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
817
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
818
819
# 8. Denoising loop
820
for i, t in enumerate(self.progress_bar(timesteps)):
821
# expand the latents if we are doing classifier free guidance
822
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
823
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
824
825
# predict the noise residual
826
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
827
828
# perform guidance
829
if do_classifier_free_guidance:
830
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
831
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
832
833
# compute the previous noisy sample x_t -> x_t-1
834
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
835
836
if mask is not None:
837
# masking
838
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
839
latents = (init_latents_proper * mask) + (latents * (1 - mask))
840
841
# call the callback, if provided
842
if i % callback_steps == 0:
843
if callback is not None:
844
callback(i, t, latents)
845
if is_cancelled_callback is not None and is_cancelled_callback():
846
return None
847
848
# 9. Post-processing
849
image = self.decode_latents(latents)
850
851
# 10. Run safety checker
852
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
853
854
# 11. Convert to PIL
855
if output_type == "pil":
856
image = self.numpy_to_pil(image)
857
858
if not return_dict:
859
return image, has_nsfw_concept
860
861
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
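
    # Calling sketch (prompt and arguments are illustrative): `__call__` dispatches on which of
    # `image` / `mask_image` are provided, so one entry point covers text2img, img2img and
    # inpainting, e.g.
    #
    #   result = pipe("a ((masterpiece)) portrait, (sharp focus:1.2)", num_inference_steps=30)
    #   image = result.images[0]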
862
863
def text2img(
864
self,
865
prompt: Union[str, List[str]],
866
negative_prompt: Optional[Union[str, List[str]]] = None,
867
height: int = 512,
868
width: int = 512,
869
num_inference_steps: int = 50,
870
guidance_scale: float = 7.5,
871
num_images_per_prompt: Optional[int] = 1,
872
eta: float = 0.0,
873
generator: Optional[torch.Generator] = None,
874
latents: Optional[torch.FloatTensor] = None,
875
max_embeddings_multiples: Optional[int] = 3,
876
output_type: Optional[str] = "pil",
877
return_dict: bool = True,
878
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
879
is_cancelled_callback: Optional[Callable[[], bool]] = None,
880
callback_steps: int = 1,
881
):
882
r"""
883
Function for text-to-image generation.
884
Args:
885
prompt (`str` or `List[str]`):
886
The prompt or prompts to guide the image generation.
887
negative_prompt (`str` or `List[str]`, *optional*):
888
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
889
if `guidance_scale` is less than `1`).
890
height (`int`, *optional*, defaults to 512):
891
The height in pixels of the generated image.
892
width (`int`, *optional*, defaults to 512):
893
The width in pixels of the generated image.
894
num_inference_steps (`int`, *optional*, defaults to 50):
895
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
896
expense of slower inference.
897
guidance_scale (`float`, *optional*, defaults to 7.5):
898
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
899
`guidance_scale` is defined as `w` of equation 2. of [Imagen
900
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
901
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
902
usually at the expense of lower image quality.
903
num_images_per_prompt (`int`, *optional*, defaults to 1):
904
The number of images to generate per prompt.
905
eta (`float`, *optional*, defaults to 0.0):
906
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
907
[`schedulers.DDIMScheduler`], will be ignored for others.
908
generator (`torch.Generator`, *optional*):
909
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
910
deterministic.
911
latents (`torch.FloatTensor`, *optional*):
912
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
913
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
914
                tensor will be generated by sampling using the supplied random `generator`.
915
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
916
The max multiple length of prompt embeddings compared to the max output length of text encoder.
917
output_type (`str`, *optional*, defaults to `"pil"`):
918
                The output format of the generated image. Choose between
919
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
920
return_dict (`bool`, *optional*, defaults to `True`):
921
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
922
plain tuple.
923
callback (`Callable`, *optional*):
924
A function that will be called every `callback_steps` steps during inference. The function will be
925
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
926
is_cancelled_callback (`Callable`, *optional*):
927
A function that will be called every `callback_steps` steps during inference. If the function returns
928
`True`, the inference will be cancelled.
929
callback_steps (`int`, *optional*, defaults to 1):
930
The frequency at which the `callback` function will be called. If not specified, the callback will be
931
called at every step.
932
Returns:
933
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
934
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
935
When returning a tuple, the first element is a list with the generated images, and the second element is a
936
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
937
(nsfw) content, according to the `safety_checker`.
938
"""
939
return self.__call__(
940
prompt=prompt,
941
negative_prompt=negative_prompt,
942
height=height,
943
width=width,
944
num_inference_steps=num_inference_steps,
945
guidance_scale=guidance_scale,
946
num_images_per_prompt=num_images_per_prompt,
947
eta=eta,
948
generator=generator,
949
latents=latents,
950
max_embeddings_multiples=max_embeddings_multiples,
951
output_type=output_type,
952
return_dict=return_dict,
953
callback=callback,
954
is_cancelled_callback=is_cancelled_callback,
955
callback_steps=callback_steps,
956
)
957
958
def img2img(
959
self,
960
image: Union[torch.FloatTensor, PIL.Image.Image],
961
prompt: Union[str, List[str]],
962
negative_prompt: Optional[Union[str, List[str]]] = None,
963
strength: float = 0.8,
964
num_inference_steps: Optional[int] = 50,
965
guidance_scale: Optional[float] = 7.5,
966
num_images_per_prompt: Optional[int] = 1,
967
eta: Optional[float] = 0.0,
968
generator: Optional[torch.Generator] = None,
969
max_embeddings_multiples: Optional[int] = 3,
970
output_type: Optional[str] = "pil",
971
return_dict: bool = True,
972
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
973
is_cancelled_callback: Optional[Callable[[], bool]] = None,
974
callback_steps: int = 1,
975
):
976
r"""
977
Function for image-to-image generation.
978
Args:
979
image (`torch.FloatTensor` or `PIL.Image.Image`):
980
`Image`, or tensor representing an image batch, that will be used as the starting point for the
981
process.
982
prompt (`str` or `List[str]`):
983
The prompt or prompts to guide the image generation.
984
negative_prompt (`str` or `List[str]`, *optional*):
985
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
986
if `guidance_scale` is less than `1`).
987
strength (`float`, *optional*, defaults to 0.8):
988
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
989
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
990
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
991
noise will be maximum and the denoising process will run for the full number of iterations specified in
992
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
993
num_inference_steps (`int`, *optional*, defaults to 50):
994
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
995
expense of slower inference. This parameter will be modulated by `strength`.
996
guidance_scale (`float`, *optional*, defaults to 7.5):
997
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
998
`guidance_scale` is defined as `w` of equation 2. of [Imagen
999
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1000
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1001
usually at the expense of lower image quality.
1002
num_images_per_prompt (`int`, *optional*, defaults to 1):
1003
The number of images to generate per prompt.
1004
eta (`float`, *optional*, defaults to 0.0):
1005
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1006
[`schedulers.DDIMScheduler`], will be ignored for others.
1007
generator (`torch.Generator`, *optional*):
1008
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1009
deterministic.
1010
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1011
The max multiple length of prompt embeddings compared to the max output length of text encoder.
1012
output_type (`str`, *optional*, defaults to `"pil"`):
1013
                The output format of the generated image. Choose between
1014
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1015
return_dict (`bool`, *optional*, defaults to `True`):
1016
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1017
plain tuple.
1018
callback (`Callable`, *optional*):
1019
A function that will be called every `callback_steps` steps during inference. The function will be
1020
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1021
is_cancelled_callback (`Callable`, *optional*):
1022
A function that will be called every `callback_steps` steps during inference. If the function returns
1023
`True`, the inference will be cancelled.
1024
callback_steps (`int`, *optional*, defaults to 1):
1025
The frequency at which the `callback` function will be called. If not specified, the callback will be
1026
called at every step.
1027
Returns:
1028
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1029
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1030
When returning a tuple, the first element is a list with the generated images, and the second element is a
1031
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1032
(nsfw) content, according to the `safety_checker`.
1033
"""
1034
return self.__call__(
1035
prompt=prompt,
1036
negative_prompt=negative_prompt,
1037
image=image,
1038
num_inference_steps=num_inference_steps,
1039
guidance_scale=guidance_scale,
1040
strength=strength,
1041
num_images_per_prompt=num_images_per_prompt,
1042
eta=eta,
1043
generator=generator,
1044
max_embeddings_multiples=max_embeddings_multiples,
1045
output_type=output_type,
1046
return_dict=return_dict,
1047
callback=callback,
1048
is_cancelled_callback=is_cancelled_callback,
1049
callback_steps=callback_steps,
1050
)
1051
1052
def inpaint(
1053
self,
1054
image: Union[torch.FloatTensor, PIL.Image.Image],
1055
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1056
prompt: Union[str, List[str]],
1057
negative_prompt: Optional[Union[str, List[str]]] = None,
1058
strength: float = 0.8,
1059
num_inference_steps: Optional[int] = 50,
1060
guidance_scale: Optional[float] = 7.5,
1061
num_images_per_prompt: Optional[int] = 1,
1062
eta: Optional[float] = 0.0,
1063
generator: Optional[torch.Generator] = None,
1064
max_embeddings_multiples: Optional[int] = 3,
1065
output_type: Optional[str] = "pil",
1066
return_dict: bool = True,
1067
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1068
is_cancelled_callback: Optional[Callable[[], bool]] = None,
1069
callback_steps: int = 1,
1070
):
1071
r"""
1072
        Function for inpainting.
1073
Args:
1074
image (`torch.FloatTensor` or `PIL.Image.Image`):
1075
`Image`, or tensor representing an image batch, that will be used as the starting point for the
1076
process. This is the image whose masked region will be inpainted.
1077
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1078
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1079
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1080
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1081
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1082
prompt (`str` or `List[str]`):
1083
The prompt or prompts to guide the image generation.
1084
negative_prompt (`str` or `List[str]`, *optional*):
1085
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1086
if `guidance_scale` is less than `1`).
1087
strength (`float`, *optional*, defaults to 0.8):
1088
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1089
is 1, the denoising process will be run on the masked area for the full number of iterations specified
1090
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1091
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1092
num_inference_steps (`int`, *optional*, defaults to 50):
1093
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1094
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1095
guidance_scale (`float`, *optional*, defaults to 7.5):
1096
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1097
`guidance_scale` is defined as `w` of equation 2. of [Imagen
1098
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1099
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1100
usually at the expense of lower image quality.
1101
num_images_per_prompt (`int`, *optional*, defaults to 1):
1102
The number of images to generate per prompt.
1103
eta (`float`, *optional*, defaults to 0.0):
1104
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1105
[`schedulers.DDIMScheduler`], will be ignored for others.
1106
generator (`torch.Generator`, *optional*):
1107
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1108
deterministic.
1109
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1110
The max multiple length of prompt embeddings compared to the max output length of text encoder.
1111
output_type (`str`, *optional*, defaults to `"pil"`):
1112
                The output format of the generated image. Choose between
1113
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1114
return_dict (`bool`, *optional*, defaults to `True`):
1115
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1116
plain tuple.
1117
callback (`Callable`, *optional*):
1118
A function that will be called every `callback_steps` steps during inference. The function will be
1119
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1120
is_cancelled_callback (`Callable`, *optional*):
1121
A function that will be called every `callback_steps` steps during inference. If the function returns
1122
`True`, the inference will be cancelled.
1123
callback_steps (`int`, *optional*, defaults to 1):
1124
The frequency at which the `callback` function will be called. If not specified, the callback will be
1125
called at every step.
1126
Returns:
1127
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1128
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1129
When returning a tuple, the first element is a list with the generated images, and the second element is a
1130
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1131
(nsfw) content, according to the `safety_checker`.
1132
"""
1133
return self.__call__(
1134
prompt=prompt,
1135
negative_prompt=negative_prompt,
1136
image=image,
1137
mask_image=mask_image,
1138
num_inference_steps=num_inference_steps,
1139
guidance_scale=guidance_scale,
1140
strength=strength,
1141
num_images_per_prompt=num_images_per_prompt,
1142
eta=eta,
1143
generator=generator,
1144
max_embeddings_multiples=max_embeddings_multiples,
1145
output_type=output_type,
1146
return_dict=return_dict,
1147
callback=callback,
1148
is_cancelled_callback=is_cancelled_callback,
1149
callback_steps=callback_steps,
1150
)
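

# Minimal end-to-end sketch, guarded so it never runs on import. The model id, prompts and output
# file name are assumptions for illustration, not part of the pipeline itself.
if __name__ == "__main__":
    pipe = StableDiffusionLongPromptWeightingPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # text-to-image with a long, weighted prompt
    image = pipe.text2img(
        "a (photorealistic:1.2) medieval castle on a cliff, ((golden hour)), highly detailed",
        negative_prompt="lowres, (blurry:1.3)",
        num_inference_steps=30,
    ).images[0]
    image.save("castle.png")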