# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionInpaintPipelineLegacy,
    UNet2DConditionModel,
    UNet2DModel,
    VQModel,
)
from diffusers.utils import floats_tensor, load_image, nightly, slow, torch_device
from diffusers.utils.testing_utils import load_numpy, require_torch_gpu


torch.backends.cuda.matmul.allow_tf32 = False


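# The fast tests below assemble the legacy inpaint pipeline from tiny, seeded
# random components so they run on CPU in seconds; the @slow and @nightly
# suites further down load full pretrained checkpoints and are gated behind
# require_torch_gpu.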
class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
        return image

    @property
    def dummy_uncond_unet(self):
        torch.manual_seed(0)
        model = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        return model

    @property
    def dummy_cond_unet(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        return model

    @property
    def dummy_cond_unet_inpaint(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=9,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        return model

    @property
    def dummy_vq_model(self):
        torch.manual_seed(0)
        model = VQModel(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=3,
        )
        return model

    @property
    def dummy_vae(self):
        torch.manual_seed(0)
        model = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        return model

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModel(config)

    @property
    def dummy_extractor(self):
        def extract(*args, **kwargs):
            class Out:
                def __init__(self):
                    self.pixel_values = torch.ones([0])

                def to(self, device):
                    self.pixel_values.to(device)
                    return self

            return Out()

        return extract

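    # Note: the pipeline under test is built with the 4-channel text-to-image
    # UNet (dummy_cond_unet). The 9-channel dummy_cond_unet_inpaint mirrors the
    # input layout of the non-legacy StableDiffusionInpaintPipeline and is not
    # referenced elsewhere in this file.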
    def test_stable_diffusion_inpaint_legacy(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionInpaintPipelineLegacy(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
        )

        image = output.images

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_inpaint_legacy_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionInpaintPipelineLegacy(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        negative_prompt = "french fries"
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            prompt,
            negative_prompt=negative_prompt,
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_inpaint_legacy_num_images_per_prompt(self):
        device = "cpu"
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionInpaintPipelineLegacy(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"

        # test num_images_per_prompt=1 (default)
        images = sd_pipe(
            prompt,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
        ).images

        assert images.shape == (1, 32, 32, 3)

        # test num_images_per_prompt=1 (default) for batch of prompts
        batch_size = 2
        images = sd_pipe(
            [prompt] * batch_size,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
        ).images

        assert images.shape == (batch_size, 32, 32, 3)

        # test num_images_per_prompt for single prompt
        num_images_per_prompt = 2
        images = sd_pipe(
            prompt,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
            num_images_per_prompt=num_images_per_prompt,
        ).images

        assert images.shape == (num_images_per_prompt, 32, 32, 3)

        # test num_images_per_prompt for batch of prompts
        batch_size = 2
        images = sd_pipe(
            [prompt] * batch_size,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
            num_images_per_prompt=num_images_per_prompt,
        ).images

        assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)


@slow
@require_torch_gpu
class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_mask.png"
        )
        inputs = {
            "prompt": "A red cat sitting on a park bench",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 3,
            "strength": 0.75,
            "guidance_scale": 7.5,
            "output_type": "numpy",
        }
        return inputs

    def test_stable_diffusion_inpaint_legacy_pndm(self):
        pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.5665, 0.6117, 0.6430, 0.4057, 0.4594, 0.5658, 0.1596, 0.3106, 0.4305])

        assert np.abs(expected_slice - image_slice).max() < 1e-4

    def test_stable_diffusion_inpaint_legacy_k_lms(self):
        pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None
        )
        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.4534, 0.4467, 0.4329, 0.4329, 0.4339, 0.4220, 0.4244, 0.4332, 0.4426])

        assert np.abs(expected_slice - image_slice).max() < 1e-4

    def test_stable_diffusion_inpaint_legacy_intermediate_state(self):
        number_of_steps = 0

        def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
            callback_fn.has_been_called = True
            nonlocal number_of_steps
            number_of_steps += 1
            if step == 1:
                latents = latents.detach().cpu().numpy()
                assert latents.shape == (1, 4, 64, 64)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array([0.5977, 1.5449, 1.0586, -0.3250, 0.7383, -0.0862, 0.4631, -0.2571, -1.1289])

                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
            elif step == 2:
                latents = latents.detach().cpu().numpy()
                assert latents.shape == (1, 4, 64, 64)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array([0.5190, 1.1621, 0.6885, 0.2424, 0.3337, -0.1617, 0.6914, -0.1957, -0.5474])

                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3

        callback_fn.has_been_called = False

        pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
        )
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)
        pipe(**inputs, callback=callback_fn, callback_steps=1)
        assert callback_fn.has_been_called
        assert number_of_steps == 2


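# Unlike the slow tests above, the nightly suite runs a full 50-step sampling
# pass and compares the entire decoded image against precomputed .npy
# references (via load_numpy) rather than a 3x3 slice.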
@nightly
@require_torch_gpu
class StableDiffusionInpaintLegacyPipelineNightlyTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_mask.png"
        )
        inputs = {
            "prompt": "A red cat sitting on a park bench",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 50,
            "strength": 0.75,
            "guidance_scale": 7.5,
            "output_type": "numpy",
        }
        return inputs

    def test_inpaint_pndm(self):
        sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_pndm.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_inpaint_ddim(self):
        sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_ddim.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_inpaint_lms(self):
        sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_lms.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3

    def test_inpaint_dpm(self):
        sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
        sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        inputs["num_inference_steps"] = 30
        image = sd_pipe(**inputs).images[0]

        expected_image = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_dpm_multi.npy"
        )
        max_diff = np.abs(expected_image - image).max()
        assert max_diff < 1e-3


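# A minimal usage sketch (not part of the test suite and never invoked): how
# the legacy inpaint pipeline exercised above would typically be called
# outside of testing. The checkpoint name and parameter values mirror
# get_inputs(); a CUDA device is assumed.
def _example_inpaint_legacy_usage(init_image, mask_image):
    pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
        "CompVis/stable-diffusion-v1-4", safety_checker=None
    ).to("cuda")
    return pipe(
        prompt="A red cat sitting on a park bench",
        image=init_image,  # PIL.Image to edit
        mask_image=mask_image,  # PIL.Image mask selecting the region to repaint
        strength=0.75,
        guidance_scale=7.5,
        num_inference_steps=50,
    ).images[0]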