Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/img2img_inpainting.py
1448 views
1
import inspect
2
from typing import Callable, List, Optional, Tuple, Union
3
4
import numpy as np
5
import PIL
6
import torch
7
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
8
9
from diffusers import DiffusionPipeline
10
from diffusers.configuration_utils import FrozenDict
11
from diffusers.models import AutoencoderKL, UNet2DConditionModel
12
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
13
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
14
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
15
from diffusers.utils import deprecate, logging
16
17
18
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
20
21
def prepare_mask_and_masked_image(image, mask):
22
image = np.array(image.convert("RGB"))
23
image = image[None].transpose(0, 3, 1, 2)
24
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
25
26
mask = np.array(mask.convert("L"))
27
mask = mask.astype(np.float32) / 255.0
28
mask = mask[None, None]
29
mask[mask < 0.5] = 0
30
mask[mask >= 0.5] = 1
31
mask = torch.from_numpy(mask)
32
33
masked_image = image * (mask < 0.5)
34
35
return mask, masked_image
36
37
38
def check_size(image, height, width):
39
if isinstance(image, PIL.Image.Image):
40
w, h = image.size
41
elif isinstance(image, torch.Tensor):
42
*_, h, w = image.shape
43
44
if h != height or w != width:
45
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
46
47
48
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)):
49
inner_image = inner_image.convert("RGBA")
50
image = image.convert("RGB")
51
52
image.paste(inner_image, paste_offset, inner_image)
53
image = image.convert("RGB")
54
55
return image
56
57
58
class ImageToImageInpaintingPipeline(DiffusionPipeline):
59
r"""
60
Pipeline for text-guided image-to-image inpainting using Stable Diffusion. *This is an experimental feature*.
61
62
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
63
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
64
65
Args:
66
vae ([`AutoencoderKL`]):
67
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
68
text_encoder ([`CLIPTextModel`]):
69
Frozen text-encoder. Stable Diffusion uses the text portion of
70
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
71
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
72
tokenizer (`CLIPTokenizer`):
73
Tokenizer of class
74
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
75
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
76
scheduler ([`SchedulerMixin`]):
77
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
78
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
79
safety_checker ([`StableDiffusionSafetyChecker`]):
80
Classification module that estimates whether generated images could be considered offensive or harmful.
81
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
82
feature_extractor ([`CLIPImageProcessor`]):
83
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
84
"""
85
86
def __init__(
87
self,
88
vae: AutoencoderKL,
89
text_encoder: CLIPTextModel,
90
tokenizer: CLIPTokenizer,
91
unet: UNet2DConditionModel,
92
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
93
safety_checker: StableDiffusionSafetyChecker,
94
feature_extractor: CLIPImageProcessor,
95
):
96
super().__init__()
97
98
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
99
deprecation_message = (
100
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
101
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
102
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
103
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
104
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
105
" file"
106
)
107
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
108
new_config = dict(scheduler.config)
109
new_config["steps_offset"] = 1
110
scheduler._internal_dict = FrozenDict(new_config)
111
112
if safety_checker is None:
113
logger.warning(
114
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
115
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
116
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
117
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
118
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
119
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
120
)
121
122
self.register_modules(
123
vae=vae,
124
text_encoder=text_encoder,
125
tokenizer=tokenizer,
126
unet=unet,
127
scheduler=scheduler,
128
safety_checker=safety_checker,
129
feature_extractor=feature_extractor,
130
)
131
132
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
133
r"""
134
Enable sliced attention computation.
135
136
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
137
in several steps. This is useful to save some memory in exchange for a small speed decrease.
138
139
Args:
140
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
141
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
142
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
143
`attention_head_dim` must be a multiple of `slice_size`.
144
"""
145
if slice_size == "auto":
146
# half the attention head size is usually a good trade-off between
147
# speed and memory
148
slice_size = self.unet.config.attention_head_dim // 2
149
self.unet.set_attention_slice(slice_size)
150
151
def disable_attention_slicing(self):
152
r"""
153
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
154
back to computing attention in one step.
155
"""
156
# set slice_size = `None` to disable `attention slicing`
157
self.enable_attention_slicing(None)
158
159
@torch.no_grad()
160
def __call__(
161
self,
162
prompt: Union[str, List[str]],
163
image: Union[torch.FloatTensor, PIL.Image.Image],
164
inner_image: Union[torch.FloatTensor, PIL.Image.Image],
165
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
166
height: int = 512,
167
width: int = 512,
168
num_inference_steps: int = 50,
169
guidance_scale: float = 7.5,
170
negative_prompt: Optional[Union[str, List[str]]] = None,
171
num_images_per_prompt: Optional[int] = 1,
172
eta: float = 0.0,
173
generator: Optional[torch.Generator] = None,
174
latents: Optional[torch.FloatTensor] = None,
175
output_type: Optional[str] = "pil",
176
return_dict: bool = True,
177
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
178
callback_steps: int = 1,
179
**kwargs,
180
):
181
r"""
182
Function invoked when calling the pipeline for generation.
183
184
Args:
185
prompt (`str` or `List[str]`):
186
The prompt or prompts to guide the image generation.
187
image (`torch.Tensor` or `PIL.Image.Image`):
188
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
189
be masked out with `mask_image` and repainted according to `prompt`.
190
inner_image (`torch.Tensor` or `PIL.Image.Image`):
191
`Image`, or tensor representing an image batch which will be overlayed onto `image`. Non-transparent
192
regions of `inner_image` must fit inside white pixels in `mask_image`. Expects four channels, with
193
the last channel representing the alpha channel, which will be used to blend `inner_image` with
194
`image`. If not provided, it will be forcibly cast to RGBA.
195
mask_image (`PIL.Image.Image`):
196
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
197
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
198
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
199
instead of 3, so the expected shape would be `(B, H, W, 1)`.
200
height (`int`, *optional*, defaults to 512):
201
The height in pixels of the generated image.
202
width (`int`, *optional*, defaults to 512):
203
The width in pixels of the generated image.
204
num_inference_steps (`int`, *optional*, defaults to 50):
205
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
206
expense of slower inference.
207
guidance_scale (`float`, *optional*, defaults to 7.5):
208
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
209
`guidance_scale` is defined as `w` of equation 2. of [Imagen
210
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
211
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
212
usually at the expense of lower image quality.
213
negative_prompt (`str` or `List[str]`, *optional*):
214
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
215
if `guidance_scale` is less than `1`).
216
num_images_per_prompt (`int`, *optional*, defaults to 1):
217
The number of images to generate per prompt.
218
eta (`float`, *optional*, defaults to 0.0):
219
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
220
[`schedulers.DDIMScheduler`], will be ignored for others.
221
generator (`torch.Generator`, *optional*):
222
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
223
deterministic.
224
latents (`torch.FloatTensor`, *optional*):
225
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
226
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
227
tensor will ge generated by sampling using the supplied random `generator`.
228
output_type (`str`, *optional*, defaults to `"pil"`):
229
The output format of the generate image. Choose between
230
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
231
return_dict (`bool`, *optional*, defaults to `True`):
232
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
233
plain tuple.
234
callback (`Callable`, *optional*):
235
A function that will be called every `callback_steps` steps during inference. The function will be
236
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
237
callback_steps (`int`, *optional*, defaults to 1):
238
The frequency at which the `callback` function will be called. If not specified, the callback will be
239
called at every step.
240
241
Returns:
242
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
243
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
244
When returning a tuple, the first element is a list with the generated images, and the second element is a
245
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
246
(nsfw) content, according to the `safety_checker`.
247
"""
248
249
if isinstance(prompt, str):
250
batch_size = 1
251
elif isinstance(prompt, list):
252
batch_size = len(prompt)
253
else:
254
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
255
256
if height % 8 != 0 or width % 8 != 0:
257
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
258
259
if (callback_steps is None) or (
260
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
261
):
262
raise ValueError(
263
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
264
f" {type(callback_steps)}."
265
)
266
267
# check if input sizes are correct
268
check_size(image, height, width)
269
check_size(inner_image, height, width)
270
check_size(mask_image, height, width)
271
272
# get prompt text embeddings
273
text_inputs = self.tokenizer(
274
prompt,
275
padding="max_length",
276
max_length=self.tokenizer.model_max_length,
277
return_tensors="pt",
278
)
279
text_input_ids = text_inputs.input_ids
280
281
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
282
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
283
logger.warning(
284
"The following part of your input was truncated because CLIP can only handle sequences up to"
285
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
286
)
287
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
288
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
289
290
# duplicate text embeddings for each generation per prompt, using mps friendly method
291
bs_embed, seq_len, _ = text_embeddings.shape
292
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
293
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
294
295
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
296
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
297
# corresponds to doing no classifier free guidance.
298
do_classifier_free_guidance = guidance_scale > 1.0
299
# get unconditional embeddings for classifier free guidance
300
if do_classifier_free_guidance:
301
uncond_tokens: List[str]
302
if negative_prompt is None:
303
uncond_tokens = [""]
304
elif type(prompt) is not type(negative_prompt):
305
raise TypeError(
306
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
307
f" {type(prompt)}."
308
)
309
elif isinstance(negative_prompt, str):
310
uncond_tokens = [negative_prompt]
311
elif batch_size != len(negative_prompt):
312
raise ValueError(
313
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
314
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
315
" the batch size of `prompt`."
316
)
317
else:
318
uncond_tokens = negative_prompt
319
320
max_length = text_input_ids.shape[-1]
321
uncond_input = self.tokenizer(
322
uncond_tokens,
323
padding="max_length",
324
max_length=max_length,
325
truncation=True,
326
return_tensors="pt",
327
)
328
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
329
330
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
331
seq_len = uncond_embeddings.shape[1]
332
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
333
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
334
335
# For classifier free guidance, we need to do two forward passes.
336
# Here we concatenate the unconditional and text embeddings into a single batch
337
# to avoid doing two forward passes
338
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
339
340
# get the initial random noise unless the user supplied it
341
# Unlike in other pipelines, latents need to be generated in the target device
342
# for 1-to-1 results reproducibility with the CompVis implementation.
343
# However this currently doesn't work in `mps`.
344
num_channels_latents = self.vae.config.latent_channels
345
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
346
latents_dtype = text_embeddings.dtype
347
if latents is None:
348
if self.device.type == "mps":
349
# randn does not exist on mps
350
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
351
self.device
352
)
353
else:
354
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
355
else:
356
if latents.shape != latents_shape:
357
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
358
latents = latents.to(self.device)
359
360
# overlay the inner image
361
image = overlay_inner_image(image, inner_image)
362
363
# prepare mask and masked_image
364
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
365
mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
366
masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
367
368
# resize the mask to latents shape as we concatenate the mask to the latents
369
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
370
371
# encode the mask image into latents space so we can concatenate it to the latents
372
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
373
masked_image_latents = 0.18215 * masked_image_latents
374
375
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
376
mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
377
masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
378
379
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
380
masked_image_latents = (
381
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
382
)
383
384
num_channels_mask = mask.shape[1]
385
num_channels_masked_image = masked_image_latents.shape[1]
386
387
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
388
raise ValueError(
389
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
390
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
391
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
392
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
393
" `pipeline.unet` or your `mask_image` or `image` input."
394
)
395
396
# set timesteps
397
self.scheduler.set_timesteps(num_inference_steps)
398
399
# Some schedulers like PNDM have timesteps as arrays
400
# It's more optimized to move all timesteps to correct device beforehand
401
timesteps_tensor = self.scheduler.timesteps.to(self.device)
402
403
# scale the initial noise by the standard deviation required by the scheduler
404
latents = latents * self.scheduler.init_noise_sigma
405
406
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
407
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
408
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
409
# and should be between [0, 1]
410
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
411
extra_step_kwargs = {}
412
if accepts_eta:
413
extra_step_kwargs["eta"] = eta
414
415
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
416
# expand the latents if we are doing classifier free guidance
417
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
418
419
# concat latents, mask, masked_image_latents in the channel dimension
420
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
421
422
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
423
424
# predict the noise residual
425
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
426
427
# perform guidance
428
if do_classifier_free_guidance:
429
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
430
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
431
432
# compute the previous noisy sample x_t -> x_t-1
433
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
434
435
# call the callback, if provided
436
if callback is not None and i % callback_steps == 0:
437
callback(i, t, latents)
438
439
latents = 1 / 0.18215 * latents
440
image = self.vae.decode(latents).sample
441
442
image = (image / 2 + 0.5).clamp(0, 1)
443
444
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
445
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
446
447
if self.safety_checker is not None:
448
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
449
self.device
450
)
451
image, has_nsfw_concept = self.safety_checker(
452
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
453
)
454
else:
455
has_nsfw_concept = None
456
457
if output_type == "pil":
458
image = self.numpy_to_pil(image)
459
460
if not return_dict:
461
return (image, has_nsfw_concept)
462
463
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
464
465