GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/multilingual_stable_diffusion.py
import inspect
from typing import Callable, List, Optional, Union

import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
    pipeline,
)

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def detect_language(pipe, prompt, batch_size):
    """helper function to detect language(s) of prompt"""

    if batch_size == 1:
        preds = pipe(prompt, top_k=1, truncation=True, max_length=128)
        return preds[0]["label"]
    else:
        detected_languages = []
        for p in prompt:
            preds = pipe(p, top_k=1, truncation=True, max_length=128)
            detected_languages.append(preds[0]["label"])

        return detected_languages
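
# A sketch of what `detect_language` returns (the language-identification checkpoint
# below is an illustrative choice, not required by this file):
#
#   detector = pipeline("text-classification", model="papluca/xlm-roberta-base-language-detection")
#   detect_language(detector, "Una casa en la playa", batch_size=1)      # -> e.g. "es"
#   detect_language(detector, ["Une maison", "Ein Haus"], batch_size=2)  # -> e.g. ["fr", "de"]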


def translate_prompt(prompt, translation_tokenizer, translation_model, device):
    """helper function to translate prompt to English"""

    encoded_prompt = translation_tokenizer(prompt, return_tensors="pt").to(device)
    generated_tokens = translation_model.generate(**encoded_prompt, max_new_tokens=1000)
    en_trans = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    return en_trans[0]
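
# Note: with MBart-50 checkpoints the source language is read from
# `translation_tokenizer.src_lang`; this helper does not change it. A usage sketch,
# assuming the many-to-English checkpoint "facebook/mbart-large-50-many-to-one-mmt":
#
#   trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
#   trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
#   trans_tokenizer.src_lang = "es_XX"
#   translate_prompt("Una casa en la playa", trans_tokenizer, trans_model, "cpu")  # -> an English translation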


class MultilingualStableDiffusion(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion in different languages.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        detection_pipeline ([`pipeline`]):
            Transformers pipeline to detect the prompt's language.
        translation_model ([`MBartForConditionalGeneration`]):
            Model to translate the prompt to English, if necessary. Please refer to the
            [model card](https://huggingface.co/docs/transformers/model_doc/mbart) for details.
        translation_tokenizer ([`MBart50TokenizerFast`]):
            Tokenizer of the translation model.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        detection_pipeline: pipeline,
        translation_model: MBartForConditionalGeneration,
        translation_tokenizer: MBart50TokenizerFast,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            detection_pipeline=detection_pipeline,
            translation_model=translation_model,
            translation_tokenizer=translation_tokenizer,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor into slices to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)
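
    # A usage sketch: calling `pipe.enable_attention_slicing()` before generation lowers peak
    # memory at a small speed cost; `pipe.disable_attention_slicing()` restores single-pass attention.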

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation. Can be in different languages.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # detect language and translate if necessary
        prompt_language = detect_language(self.detection_pipeline, prompt, batch_size)
        if batch_size == 1 and prompt_language != "en":
            prompt = translate_prompt(prompt, self.translation_tokenizer, self.translation_model, self.device)

        if isinstance(prompt, list):
            for index in range(batch_size):
                if prompt_language[index] != "en":
                    p = translate_prompt(
                        prompt[index], self.translation_tokenizer, self.translation_model, self.device
                    )
                    prompt[index] = p

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                # detect language and translate it if necessary
                negative_prompt_language = detect_language(self.detection_pipeline, negative_prompt, batch_size)
                if negative_prompt_language != "en":
                    negative_prompt = translate_prompt(
                        negative_prompt, self.translation_tokenizer, self.translation_model, self.device
                    )
                if isinstance(negative_prompt, str):
                    uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                # detect language and translate it if necessary
                if isinstance(negative_prompt, list):
                    negative_prompt_languages = detect_language(self.detection_pipeline, negative_prompt, batch_size)
                    for index in range(batch_size):
                        if negative_prompt_languages[index] != "en":
                            p = translate_prompt(
                                negative_prompt[index], self.translation_tokenizer, self.translation_model, self.device
                            )
                            negative_prompt[index] = p
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to the correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

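        # denoising loop: predict the noise residual at each timestep and step the scheduler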
        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

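        # rescale the latents (0.18215 is the Stable Diffusion VAE scaling factor) and decode them with the VAE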
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

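        # map the decoded image from [-1, 1] to [0, 1]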
        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

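        # run the safety checker (if enabled) to flag potentially NSFW outputs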
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

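        # convert the numpy images to PIL images if that output format was requested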
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)