GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/stable_diffusion_comparison.py
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker


pipe1_model_id = "CompVis/stable-diffusion-v1-1"
pipe2_model_id = "CompVis/stable-diffusion-v1-2"
pipe3_model_id = "CompVis/stable-diffusion-v1-3"
pipe4_model_id = "CompVis/stable-diffusion-v1-4"

class StableDiffusionComparisonPipeline(DiffusionPipeline):
    r"""
    Pipeline for parallel comparison of Stable Diffusion v1.1-v1.4.

    This pipeline inherits from [`DiffusionPipeline`] and requires a Hugging Face auth token to download the
    pre-trained checkpoints from the Hugging Face Hub. Pass the model ID for Stable Diffusion v1.4; the previous
    three checkpoints are loaded automatically.

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        # v1.1-v1.3 are downloaded from the Hub; v1.4 is assembled from the components passed to this pipeline.
        self.pipe1 = StableDiffusionPipeline.from_pretrained(pipe1_model_id)
        self.pipe2 = StableDiffusionPipeline.from_pretrained(pipe2_model_id)
        self.pipe3 = StableDiffusionPipeline.from_pretrained(pipe3_model_id)
        self.pipe4 = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )

        self.register_modules(pipeline1=self.pipe1, pipeline2=self.pipe2, pipeline3=self.pipe3, pipeline4=self.pipe4)

    @property
    def layers(self) -> Dict[str, Any]:
        return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor into slices to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.pipe4.unet.config.attention_head_dim // 2
        # This pipeline does not register a `unet` of its own, so apply slicing to each sub-pipeline's UNet.
        for pipe in (self.pipe1, self.pipe2, self.pipe3, self.pipe4):
            pipe.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will
        go back to computing attention in one step.
        """
        # set slice_size = `None` to disable attention slicing
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def text2img_sd1_1(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        return self.pipe1(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def text2img_sd1_2(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        return self.pipe2(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def text2img_sd1_3(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        return self.pipe3(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def text2img_sd1_4(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        return self.pipe4(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation. It runs all four pipelines for SD v1.1-v1.4
        sequentially on the same inputs and returns the four results together.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`]; ignored for other schedulers.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device)

        # Check that height and width are divisible by 8
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` must be divisible by 8 but are {height} and {width}.")

        # Get the result from Stable Diffusion checkpoint v1.1
        res1 = self.text2img_sd1_1(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

        # Get the result from Stable Diffusion checkpoint v1.2
        res2 = self.text2img_sd1_2(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

        # Get the result from Stable Diffusion checkpoint v1.3
        res3 = self.text2img_sd1_3(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

        # Get the result from Stable Diffusion checkpoint v1.4
        res4 = self.text2img_sd1_4(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

        # Collect the images and NSFW flags from all four checkpoints into a single output;
        # `StableDiffusionPipelineOutput` requires both fields, so the per-checkpoint flags are passed through too.
        return StableDiffusionPipelineOutput(
            images=[res1[0], res2[0], res3[0], res4[0]],
            nsfw_content_detected=[res1[1], res2[1], res3[1], res4[1]],
        )
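

# Usage sketch (not part of the upstream example file): community pipelines such as this one are typically
# loaded through `DiffusionPipeline.from_pretrained` with the `custom_pipeline` argument. The exact loading
# behaviour depends on the installed diffusers version, so treat this block as an illustrative assumption.
if __name__ == "__main__":
    # Passing the v1.4 model ID supplies the components for pipe4; v1.1-v1.3 are fetched automatically in __init__.
    pipe = DiffusionPipeline.from_pretrained(
        pipe4_model_id,
        custom_pipeline="stable_diffusion_comparison",
    )
    pipe.enable_attention_slicing()

    output = pipe(prompt="an astronaut riding a horse on mars", num_inference_steps=25)

    # `output.images` holds one list of images per checkpoint (v1.1 through v1.4), so the first image
    # of each sub-result can be saved side by side for comparison.
    for version, images in enumerate(output.images, start=1):
        images[0].save(f"stable_diffusion_v1_{version}.png")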