# GitHub repository: shivamshrirao/diffusers
# Path: examples/community/wildcard_stable_diffusion.py
import inspect
import os
import random
import re
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union

import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

global_re_wildcard = re.compile(r"__([^_]*)__")


def get_filename(path: str):
    # this doesn't work on Windows
    return os.path.basename(path).split(".txt")[0]


def read_wildcard_values(path: str):
    with open(path, encoding="utf8") as f:
        return f.read().splitlines()


def grab_wildcard_values(wildcard_option_dict: Dict[str, List[str]] = {}, wildcard_files: List[str] = []):
    for wildcard_file in wildcard_files:
        filename = get_filename(wildcard_file)
        read_values = read_wildcard_values(wildcard_file)
        if filename not in wildcard_option_dict:
            wildcard_option_dict[filename] = []
        wildcard_option_dict[filename].extend(read_values)
    return wildcard_option_dict


def replace_prompt_with_wildcards(
    prompt: str, wildcard_option_dict: Dict[str, List[str]] = {}, wildcard_files: List[str] = []
):
    new_prompt = prompt

    # get wildcard options
    wildcard_option_dict = grab_wildcard_values(wildcard_option_dict, wildcard_files)

    for m in global_re_wildcard.finditer(new_prompt):
        wildcard_value = m.group()
        replace_value = random.choice(wildcard_option_dict[wildcard_value.strip("__")])
        new_prompt = new_prompt.replace(wildcard_value, replace_value, 1)

    return new_prompt
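
# A minimal illustration of the helpers above (assuming an options dict such as
# {"animal": ["dog", "cat", "fox"]}):
#     replace_prompt_with_wildcards("A __animal__ sitting on a chair", {"animal": ["dog", "cat", "fox"]})
# could return e.g. "A fox sitting on a chair"; every "__name__" token is swapped for a random
# entry from the matching list (or from a "name.txt" file passed via `wildcard_files`,
# which is read as one replacement value per line).
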

@dataclass
class WildcardStableDiffusionOutput(StableDiffusionPipelineOutput):
    prompts: List[str]


class WildcardStableDiffusionPipeline(DiffusionPipeline):
    r"""
    Example Usage:
        pipe = WildcardStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            torch_dtype=torch.float16,
        )
        prompt = "__animal__ sitting on a __object__ wearing a __clothing__"
        out = pipe(
            prompt,
            wildcard_option_dict={
                "clothing": ["hat", "shirt", "scarf", "beret"]
            },
            wildcard_files=["object.txt", "animal.txt"],
            num_prompt_samples=1
        )

    Pipeline for text-to-image generation with wildcards using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
                " it only for use cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        wildcard_option_dict: Dict[str, List[str]] = {},
        wildcard_files: List[str] = [],
        num_prompt_samples: Optional[int] = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation. May contain `__wildcard__` tokens.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`]; ignored for other schedulers.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            wildcard_option_dict (`Dict[str, List[str]]`, *optional*):
                Dict mapping a wildcard name to a list of possible replacement values. For example, for the prompt
                "A __animal__ sitting on a chair", `{"animal": ["dog", "cat", "fox"]}` provides the candidate values
                for the `__animal__` token.
            wildcard_files (`List[str]`, *optional*):
                List of paths to plain-text files providing wildcard replacements, one value per line; the file stem
                (e.g. "animal" for "animal.txt") is the wildcard name. For the prompt "A __animal__ sitting on a
                chair", `["animal.txt"]` could be provided.
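                For illustration, a hypothetical `animal.txt` would simply list one candidate value per line:

                    dog
                    cat
                    fox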
            num_prompt_samples (`int`, *optional*, defaults to 1):
                Number of times to sample the wildcards for each prompt provided.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`. When `return_dict` is True, the returned
            `WildcardStableDiffusionOutput` additionally contains `prompts`, the prompts after wildcard replacement.
        """

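        # Resolve wildcards up front: every input prompt is expanded into `num_prompt_samples`
        # concrete prompts, and the batch size becomes the total number of expanded prompts.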
        if isinstance(prompt, str):
            prompt = [
                replace_prompt_with_wildcards(prompt, wildcard_option_dict, wildcard_files)
                for i in range(num_prompt_samples)
            ]
            batch_size = len(prompt)
        elif isinstance(prompt, list):
            prompt_list = []
            for p in prompt:
                for i in range(num_prompt_samples):
                    prompt_list.append(replace_prompt_with_wildcards(p, wildcard_option_dict, wildcard_files))
            prompt = prompt_list
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not exist on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more efficient to move all timesteps to the correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

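        # denoising loop: iteratively predict the noise residual and step the scheduler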
        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

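        # decode the latents with the VAE; 0.18215 is the Stable Diffusion latent scaling factor,
        # so dividing by it maps the latents back to the VAE's latent space before decoding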
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

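        # run the safety checker (if it was not disabled) to flag potentially NSFW outputs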
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return WildcardStableDiffusionOutput(images=image, nsfw_content_detected=has_nsfw_concept, prompts=prompt)
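

if __name__ == "__main__":
    # Minimal usage sketch mirroring the docstring example above. It is illustrative only and
    # assumes that "object.txt" and "animal.txt" exist in the working directory (one replacement
    # value per line) and that a CUDA device is available.
    pipe = WildcardStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
    ).to("cuda")
    out = pipe(
        "__animal__ sitting on a __object__ wearing a __clothing__",
        wildcard_option_dict={"clothing": ["hat", "shirt", "scarf", "beret"]},
        wildcard_files=["object.txt", "animal.txt"],
        num_prompt_samples=1,
    )
    # `out.prompts` holds the prompts after wildcard replacement; `out.images` the generated PIL images.
    print(out.prompts)
    out.images[0].save("wildcard_sample.png")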