# GitHub repository: shivamshrirao/diffusers
# Path: blob/main/examples/community/lpw_stable_diffusion_onnx.py
import inspect
import re
from typing import Callable, List, Optional, Union

import numpy as np
import PIL
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTokenizer

import diffusers
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import logging

try:
    from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
except ImportError:
    ORT_TO_NP_TYPE = {
        "tensor(bool)": np.bool_,
        "tensor(int8)": np.int8,
        "tensor(uint8)": np.uint8,
        "tensor(int16)": np.int16,
        "tensor(uint16)": np.uint16,
        "tensor(int32)": np.int32,
        "tensor(uint32)": np.uint32,
        "tensor(int64)": np.int64,
        "tensor(uint64)": np.uint64,
        "tensor(float16)": np.float16,
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
    }

try:
    from diffusers.utils import PIL_INTERPOLATION
except ImportError:
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        PIL_INTERPOLATION = {
            "linear": PIL.Image.Resampling.BILINEAR,
            "bilinear": PIL.Image.Resampling.BILINEAR,
            "bicubic": PIL.Image.Resampling.BICUBIC,
            "lanczos": PIL.Image.Resampling.LANCZOS,
            "nearest": PIL.Image.Resampling.NEAREST,
        }
    else:
        PIL_INTERPOLATION = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
            "nearest": PIL.Image.NEAREST,
        }
# ------------------------------------------------------------------------------

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
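# A couple of illustrative examples beyond the doctests above (a hedged sketch, not upstream
# documentation): with the default multipliers, nested round brackets compound multiplicatively,
# square brackets divide by the same factor, and an explicit (word:weight) overrides both, e.g.
#   parse_prompt_attention("a ((red)) ball")   -> [['a ', 1.0], ['red', ~1.21], [' ball', 1.0]]
#   parse_prompt_attention("a [red] ball")     -> [['a ', 1.0], ['red', ~0.909], [' ball', 1.0]]
#   parse_prompt_attention("a (red:1.5) ball") -> [['a ', 1.0], ['red', 1.5], [' ball', 1.0]]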


def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return their tokens together with the weight of each token.

    No padding, starting or ending token is included.
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1]
            text_token += list(token)
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights


def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights


def get_unweighted_text_embeddings(
    pipe,
    text_input: np.array,
    chunk_length: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens exceeds the capacity of the text encoder, the input
    should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].copy()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            text_input_chunk[:, -1] = text_input[0, -1]

            text_embedding = pipe.text_encoder(input_ids=text_input_chunk)[0]

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = np.concatenate(text_embeddings, axis=1)
    else:
        text_embeddings = pipe.text_encoder(input_ids=text_input)[0]
    return text_embeddings


def get_weighted_text_embeddings(
    pipe,
    prompt: Union[str, List[str]],
    uncond_prompt: Optional[Union[str, List[str]]] = None,
    max_embeddings_multiples: Optional[int] = 4,
    no_boseos_middle: Optional[bool] = False,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
    **kwargs,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

    Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.

    Args:
        pipe (`OnnxStableDiffusionPipeline`):
            Pipe to provide access to the tokenizer and the text encoder.
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        uncond_prompt (`str` or `List[str]`):
            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
            is provided, the embeddings of prompt and uncond_prompt are concatenated.
        max_embeddings_multiples (`int`, *optional*, defaults to `4`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            When the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep
            the starting and ending tokens in each of the middle chunks.
        skip_parsing (`bool`, *optional*, defaults to `False`):
            Skip the parsing of brackets.
        skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When parsing is skipped, it is forced to `True`.
    """
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    if not skip_parsing:
        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
    else:
        prompt_tokens = [
            token[1:-1]
            for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True, return_tensors="np").input_ids
        ]
        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens = [
                token[1:-1]
                for token in pipe.tokenizer(
                    uncond_prompt,
                    max_length=max_length,
                    truncation=True,
                    return_tensors="np",
                ).input_ids
            ]
            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])
    if uncond_prompt is not None:
        max_length = max(max_length, max([len(token) for token in uncond_tokens]))

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=pipe.tokenizer.model_max_length,
    )
    prompt_tokens = np.array(prompt_tokens, dtype=np.int32)
    if uncond_prompt is not None:
        uncond_tokens, uncond_weights = pad_tokens_and_weights(
            uncond_tokens,
            uncond_weights,
            max_length,
            bos,
            eos,
            no_boseos_middle=no_boseos_middle,
            chunk_length=pipe.tokenizer.model_max_length,
        )
        uncond_tokens = np.array(uncond_tokens, dtype=np.int32)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        pipe,
        prompt_tokens,
        pipe.tokenizer.model_max_length,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
    if uncond_prompt is not None:
        uncond_embeddings = get_unweighted_text_embeddings(
            pipe,
            uncond_tokens,
            pipe.tokenizer.model_max_length,
            no_boseos_middle=no_boseos_middle,
        )
        uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)

    # assign weights to the prompts and normalize in the sense of mean
    # TODO: should we normalize by chunk or in a whole (current implementation)?
    if (not skip_parsing) and (not skip_weighting):
        previous_mean = text_embeddings.mean(axis=(-2, -1))
        text_embeddings *= prompt_weights[:, :, None]
        text_embeddings *= (previous_mean / text_embeddings.mean(axis=(-2, -1)))[:, None, None]
        if uncond_prompt is not None:
            previous_mean = uncond_embeddings.mean(axis=(-2, -1))
            uncond_embeddings *= uncond_weights[:, :, None]
            uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=(-2, -1)))[:, None, None]

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if uncond_prompt is not None:
        return text_embeddings, uncond_embeddings

    return text_embeddings


def preprocess_image(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    return 2.0 * image - 1.0


def preprocess_mask(mask, scale_factor=8):
    mask = mask.convert("L")
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
    mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension (the identity transpose is a no-op)
    mask = 1 - mask  # repaint white, keep black
    return mask


class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
    parsing weighting in the prompt.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    """
    if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):

        def __init__(
            self,
            vae_encoder: OnnxRuntimeModel,
            vae_decoder: OnnxRuntimeModel,
            text_encoder: OnnxRuntimeModel,
            tokenizer: CLIPTokenizer,
            unet: OnnxRuntimeModel,
            scheduler: SchedulerMixin,
            safety_checker: OnnxRuntimeModel,
            feature_extractor: CLIPImageProcessor,
            requires_safety_checker: bool = True,
        ):
            super().__init__(
                vae_encoder=vae_encoder,
                vae_decoder=vae_decoder,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                unet=unet,
                scheduler=scheduler,
                safety_checker=safety_checker,
                feature_extractor=feature_extractor,
                requires_safety_checker=requires_safety_checker,
            )
            self.__init__additional__()

    else:

        def __init__(
            self,
            vae_encoder: OnnxRuntimeModel,
            vae_decoder: OnnxRuntimeModel,
            text_encoder: OnnxRuntimeModel,
            tokenizer: CLIPTokenizer,
            unet: OnnxRuntimeModel,
            scheduler: SchedulerMixin,
            safety_checker: OnnxRuntimeModel,
            feature_extractor: CLIPImageProcessor,
        ):
            super().__init__(
                vae_encoder=vae_encoder,
                vae_decoder=vae_decoder,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                unet=unet,
                scheduler=scheduler,
                safety_checker=safety_checker,
                feature_extractor=feature_extractor,
            )
            self.__init__additional__()

    def __init__additional__(self):
        self.unet_in_channels = 4
        self.vae_scale_factor = 8

    def _encode_prompt(
        self,
        prompt,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        max_embeddings_multiples,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list(int)`):
                prompt to be encoded
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
            pipe=self,
            prompt=prompt,
            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
            max_embeddings_multiples=max_embeddings_multiples,
        )

        text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
        if do_classifier_free_guidance:
            uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

        return text_embeddings
    def check_inputs(self, prompt, height, width, strength, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def get_timesteps(self, num_inference_steps, strength, is_text2img):
        if is_text2img:
            return self.scheduler.timesteps, num_inference_steps
        else:
            # get the original timestep using init_timestep
            offset = self.scheduler.config.get("steps_offset", 0)
            init_timestep = int(num_inference_steps * strength) + offset
            init_timestep = min(init_timestep, num_inference_steps)

            t_start = max(num_inference_steps - init_timestep + offset, 0)
            timesteps = self.scheduler.timesteps[t_start:]
            return timesteps, num_inference_steps - t_start

    def run_safety_checker(self, image):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(image), return_tensors="np"
            ).pixel_values.astype(image.dtype)
            # calling the safety_checker directly with batch size > 1 throws an error, so run it image by image
            images, has_nsfw_concept = [], []
            for i in range(image.shape[0]):
                image_i, has_nsfw_concept_i = self.safety_checker(
                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
                )
                images.append(image_i)
                has_nsfw_concept.append(has_nsfw_concept_i[0])
            image = np.concatenate(images)
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        # image = self.vae_decoder(latent_sample=latents)[0]
        # the half-precision vae decoder seems to give strange results when batch size > 1, so decode one latent at a time
        image = np.concatenate(
            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
        )
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def prepare_latents(self, image, timestep, batch_size, height, width, dtype, generator, latents=None):
        if image is None:
            shape = (
                batch_size,
                self.unet_in_channels,
                height // self.vae_scale_factor,
                width // self.vae_scale_factor,
            )

            if latents is None:
                latents = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
            else:
                if latents.shape != shape:
                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")

            # scale the initial noise by the standard deviation required by the scheduler
            latents = (torch.from_numpy(latents) * self.scheduler.init_noise_sigma).numpy()
            return latents, None, None
        else:
            init_latents = self.vae_encoder(sample=image)[0]
            init_latents = 0.18215 * init_latents
            init_latents = np.concatenate([init_latents] * batch_size, axis=0)
            init_latents_orig = init_latents
            shape = init_latents.shape

            # add noise to latents using the timesteps
            noise = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
            latents = self.scheduler.add_noise(
                torch.from_numpy(init_latents), torch.from_numpy(noise), timestep
            ).numpy()
            return latents, init_latents_orig, noise
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        image: Union[np.ndarray, PIL.Image.Image] = None,
        mask_image: Union[np.ndarray, PIL.Image.Image] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        strength: float = 0.8,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[np.ndarray] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        is_cancelled_callback: Optional[Callable[[], bool]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            image (`np.ndarray` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process.
            mask_image (`np.ndarray` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
                noise will be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`np.ndarray`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
            is_cancelled_callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. If the function returns
                `True`, the inference will be cancelled.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            `None` if cancelled by `is_cancelled_callback`,
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, strength, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            max_embeddings_multiples,
        )
        dtype = text_embeddings.dtype

        # 4. Preprocess image and mask
        if isinstance(image, PIL.Image.Image):
            image = preprocess_image(image)
        if image is not None:
            image = image.astype(dtype)
        if isinstance(mask_image, PIL.Image.Image):
            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
        if mask_image is not None:
            mask = mask_image.astype(dtype)
            mask = np.concatenate([mask] * batch_size * num_images_per_prompt)
        else:
            mask = None

        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timestep_dtype = next(
            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
        )
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, image is None)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        latents, init_latents_orig, noise = self.prepare_latents(
            image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            height,
            width,
            dtype,
            generator,
            latents,
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
            latent_model_input = latent_model_input.numpy()

            # predict the noise residual
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=np.array([t], dtype=timestep_dtype),
                encoder_hidden_states=text_embeddings,
            )
            noise_pred = noise_pred[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            scheduler_output = self.scheduler.step(
                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
            )
            latents = scheduler_output.prev_sample.numpy()

            if mask is not None:
                # masking
                init_latents_proper = self.scheduler.add_noise(
                    torch.from_numpy(init_latents_orig),
                    torch.from_numpy(noise),
                    t,
                ).numpy()
                latents = (init_latents_proper * mask) + (latents * (1 - mask))

            # call the callback, if provided
            if i % callback_steps == 0:
                if callback is not None:
                    callback(i, t, latents)
                if is_cancelled_callback is not None and is_cancelled_callback():
                    return None

        # 9. Post-processing
        image = self.decode_latents(latents)

        # 10. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image)

        # 11. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return image, has_nsfw_concept

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def text2img(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[np.ndarray] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function for text-to-image generation.
        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`np.ndarray`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        return self.__call__(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            max_embeddings_multiples=max_embeddings_multiples,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )
957
def img2img(
958
self,
959
image: Union[np.ndarray, PIL.Image.Image],
960
prompt: Union[str, List[str]],
961
negative_prompt: Optional[Union[str, List[str]]] = None,
962
strength: float = 0.8,
963
num_inference_steps: Optional[int] = 50,
964
guidance_scale: Optional[float] = 7.5,
965
num_images_per_prompt: Optional[int] = 1,
966
eta: Optional[float] = 0.0,
967
generator: Optional[torch.Generator] = None,
968
max_embeddings_multiples: Optional[int] = 3,
969
output_type: Optional[str] = "pil",
970
return_dict: bool = True,
971
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
972
callback_steps: int = 1,
973
**kwargs,
974
):
975
r"""
976
Function for image-to-image generation.
977
Args:
978
image (`np.ndarray` or `PIL.Image.Image`):
979
`Image`, or ndarray representing an image batch, that will be used as the starting point for the
980
process.
981
prompt (`str` or `List[str]`):
982
The prompt or prompts to guide the image generation.
983
negative_prompt (`str` or `List[str]`, *optional*):
984
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
985
if `guidance_scale` is less than `1`).
986
strength (`float`, *optional*, defaults to 0.8):
987
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
988
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
989
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
990
noise will be maximum and the denoising process will run for the full number of iterations specified in
991
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
992
num_inference_steps (`int`, *optional*, defaults to 50):
993
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
994
expense of slower inference. This parameter will be modulated by `strength`.
995
guidance_scale (`float`, *optional*, defaults to 7.5):
996
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
997
`guidance_scale` is defined as `w` of equation 2. of [Imagen
998
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
999
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1000
usually at the expense of lower image quality.
1001
num_images_per_prompt (`int`, *optional*, defaults to 1):
1002
The number of images to generate per prompt.
1003
eta (`float`, *optional*, defaults to 0.0):
1004
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1005
[`schedulers.DDIMScheduler`], will be ignored for others.
1006
generator (`torch.Generator`, *optional*):
1007
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1008
deterministic.
1009
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1010
The max multiple length of prompt embeddings compared to the max output length of text encoder.
1011
output_type (`str`, *optional*, defaults to `"pil"`):
1012
The output format of the generate image. Choose between
1013
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1014
return_dict (`bool`, *optional*, defaults to `True`):
1015
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1016
plain tuple.
1017
callback (`Callable`, *optional*):
1018
A function that will be called every `callback_steps` steps during inference. The function will be
1019
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
1020
callback_steps (`int`, *optional*, defaults to 1):
1021
The frequency at which the `callback` function will be called. If not specified, the callback will be
1022
called at every step.
1023
Returns:
1024
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1025
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1026
When returning a tuple, the first element is a list with the generated images, and the second element is a
1027
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1028
(nsfw) content, according to the `safety_checker`.
1029
"""
1030
        return self.__call__(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            strength=strength,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            max_embeddings_multiples=max_embeddings_multiples,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    def inpaint(
        self,
        image: Union[np.ndarray, PIL.Image.Image],
        mask_image: Union[np.ndarray, PIL.Image.Image],
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function for inpaint.
        Args:
            image (`np.ndarray` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. This is the image whose masked region will be inpainted.
            mask_image (`np.ndarray` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
                is 1, the denoising process will be run on the masked area for the full number of iterations specified
                in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        return self.__call__(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            mask_image=mask_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            strength=strength,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            max_embeddings_multiples=max_embeddings_multiples,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )
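# ------------------------------------------------------------------------------
# Usage sketch (not executed on import). The checkpoint id, revision and execution
# provider below are placeholders and assume an ONNX export of a Stable Diffusion
# model is available; adapt them to your own setup.
#
#   from diffusers import DiffusionPipeline
#
#   pipe = DiffusionPipeline.from_pretrained(
#       "CompVis/stable-diffusion-v1-4",  # placeholder model id with an ONNX revision
#       revision="onnx",
#       custom_pipeline="lpw_stable_diffusion_onnx",
#       provider="CPUExecutionProvider",
#   )
#   prompt = "a photo of an astronaut riding a horse on mars, (high quality:1.3)"
#   neg_prompt = "lowres, bad anatomy, (worst quality:1.4)"
#   image = pipe.text2img(prompt, negative_prompt=neg_prompt, width=512, height=512).images[0]
#   image.save("astronaut.png")
# ------------------------------------------------------------------------------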