# GitHub Repository: shivamshrirao/diffusers
# Path: blob/main/examples/community/stable_diffusion_mega.py

from typing import Any, Callable, Dict, List, Optional, Union

import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipelineLegacy,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.configuration_utils import FrozenDict
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import deprecate, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

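# Usage sketch (illustrative, not part of the pipeline itself). It assumes this file is
# loaded as a diffusers community pipeline via the `custom_pipeline` argument; the model
# id below is only an example:
#
#   pipe = DiffusionPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_mega"
#   )
#   pipe.to("cuda")
#   image = pipe.text2img("An astronaut riding a horse").images[0]

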
class StableDiffusionMegaPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image, image-to-image, and inpainting generation using Stable Diffusion. It exposes the
    `text2img`, `img2img`, and `inpaint` methods, each of which delegates to the corresponding task-specific
    Stable Diffusion pipeline built from the same components.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` unchanged might lead to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    @property
    def components(self) -> Dict[str, Any]:
        # Return the pipeline's registered (non-private) config entries as named components,
        # so they can be re-used to assemble the task-specific pipelines below.
        return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor into slices to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method goes
        back to computing attention in one step.
        """
        # set slice_size = None to disable attention slicing
        self.enable_attention_slicing(None)

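    # Usage sketch (illustrative): `pipe.enable_attention_slicing()` computes attention in two
    # steps per head ("auto"), while `pipe.enable_attention_slicing(1)` maximizes slicing (lowest
    # memory, slowest); `pipe.disable_attention_slicing()` restores single-step attention.
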
    @torch.no_grad()
    def inpaint(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        # Delegates to `StableDiffusionInpaintPipelineLegacy`; for more information on how this
        # function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion
        return StableDiffusionInpaintPipelineLegacy(**self.components)(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

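    # Usage sketch (illustrative; `init_image` and `mask` are PIL images you supply):
    #
    #   result = pipe.inpaint(
    #       prompt="a white cat sitting on a bench",
    #       image=init_image,
    #       mask_image=mask,
    #       strength=0.75,
    #   ).images[0]
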
    @torch.no_grad()
    def img2img(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        # Delegates to `StableDiffusionImg2ImgPipeline`; for more information on how this function works,
        # please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
        return StableDiffusionImg2ImgPipeline(**self.components)(
            prompt=prompt,
            image=image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

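    # Usage sketch (illustrative; `init_image` is a PIL image you supply):
    #
    #   result = pipe.img2img(
    #       prompt="fantasy landscape, highly detailed",
    #       image=init_image,
    #       strength=0.75,
    #       guidance_scale=7.5,
    #   ).images[0]
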
    @torch.no_grad()
    def text2img(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        # Delegates to `StableDiffusionPipeline`; for more information on how this function works,
        # please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline
        return StableDiffusionPipeline(**self.components)(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

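    # Usage sketch (illustrative):
    #
    #   result = pipe.text2img(
    #       "An astronaut riding a horse",
    #       height=512,
    #       width=512,
    #       num_inference_steps=50,
    #   ).images[0]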