Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
1450 views
1
# coding=utf-8
2
# Copyright 2023 HuggingFace Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import gc
17
import random
18
import unittest
19
20
import numpy as np
21
import torch
22
from PIL import Image
23
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
24
25
from diffusers import (
26
AutoencoderKL,
27
DPMSolverMultistepScheduler,
28
LMSDiscreteScheduler,
29
PNDMScheduler,
30
StableDiffusionInpaintPipeline,
31
UNet2DConditionModel,
32
)
33
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
34
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
35
from diffusers.utils.testing_utils import require_torch_gpu
36
37
from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
38
from ...test_pipelines_common import PipelineTesterMixin
39
40
41
# Disable TF32 tensor-core matmuls for every test in this module (module-level side effect).
# NOTE(review): presumably so CUDA results stay close enough to the hard-coded expected
# slices, which are compared with tight tolerances (down to 1e-4) below — confirm.
torch.backends.cuda.matmul.allow_tf32 = False
42
43
44
class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast CPU tests for `StableDiffusionInpaintPipeline` built from tiny, seeded components."""

    pipeline_class = StableDiffusionInpaintPipeline
    params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
    batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS

    def get_dummy_components(self):
        """Return the constructor kwargs for a tiny inpainting pipeline.

        The global torch seed is reset to 0 right before each randomly initialised
        model, so repeated calls produce identical weights.
        """
        torch.manual_seed(0)
        # in_channels=9 vs out_channels=4: presumably 4 noisy latents + 1 mask channel
        # + 4 masked-image latents, as expected by an inpainting UNet — confirm.
        unet = UNet2DConditionModel(
            sample_size=32,
            in_channels=9,
            out_channels=4,
            layers_per_block=2,
            block_out_channels=(32, 64),
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = PNDMScheduler(skip_prk_steps=True)

        torch.manual_seed(0)
        vae = AutoencoderKL(
            in_channels=3,
            out_channels=3,
            latent_channels=4,
            block_out_channels=[32, 64],
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
        )

        clip_config = CLIPTextConfig(
            vocab_size=1000,
            hidden_size=32,
            intermediate_size=37,
            num_hidden_layers=5,
            num_attention_heads=4,
            layer_norm_eps=1e-05,
            bos_token_id=0,
            eos_token_id=2,
            pad_token_id=1,
        )
        torch.manual_seed(0)
        text_encoder = CLIPTextModel(clip_config)

        return dict(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip"),
            safety_checker=None,
            feature_extractor=None,
        )

    def get_dummy_inputs(self, device, seed=0):
        """Return call kwargs: a prompt, 64x64 RGB image and mask, a seeded generator, 2 steps."""
        # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
        pixels = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        hwc = pixels.cpu().permute(0, 2, 3, 1)[0]
        # NOTE(review): if floats_tensor yields values in [0, 1), np.uint8 truncates them to 0,
        # making both images constant — kept as-is because the expected slices depend on it.
        init_image = Image.fromarray(np.uint8(hwc)).convert("RGB").resize((64, 64))
        mask_image = Image.fromarray(np.uint8(hwc + 4)).convert("RGB").resize((64, 64))

        # mps uses the global default generator; every other device gets a dedicated one.
        generator = (
            torch.manual_seed(seed)
            if str(device).startswith("mps")
            else torch.Generator(device=device).manual_seed(seed)
        )
        return dict(
            prompt="A painting of a squirrel eating a burger",
            image=init_image,
            mask_image=mask_image,
            generator=generator,
            num_inference_steps=2,
            guidance_scale=6.0,
            output_type="numpy",
        )

    def _pipeline_on(self, device):
        """Build the dummy pipeline, move it to `device` and set progress-bar config disable=None."""
        pipe = StableDiffusionInpaintPipeline(**self.get_dummy_components()).to(device)
        pipe.set_progress_bar_config(disable=None)
        return pipe

    def test_stable_diffusion_inpaint(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        pipe = self._pipeline_on(device)

        images = pipe(**self.get_dummy_inputs(device)).images
        corner = images[0, -3:, -3:, -1]

        assert images.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474])
        assert np.abs(corner.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_inpaint_image_tensor(self):
        """PIL inputs and the equivalent tensor inputs must give (nearly) the same output."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        pipe = self._pipeline_on(device)

        out_pil = pipe(**self.get_dummy_inputs(device)).images

        tensor_inputs = self.get_dummy_inputs(device)
        pil_image = tensor_inputs["image"]
        pil_mask = tensor_inputs["mask_image"]
        # image rescaled to [-1, 1] as (1, C, H, W); mask rescaled to [0, 1], first channel only.
        tensor_inputs["image"] = torch.tensor(np.array(pil_image) / 127.5 - 1).permute(2, 0, 1).unsqueeze(0)
        tensor_inputs["mask_image"] = torch.tensor(np.array(pil_mask) / 255).permute(2, 0, 1)[:1].unsqueeze(0)
        out_tensor = pipe(**tensor_inputs).images

        assert out_pil.shape == (1, 64, 64, 3)
        assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2
153
154
155
@slow
@require_torch_gpu
class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
    """GPU regression tests for `runwayml/stable-diffusion-inpainting` using 3 denoising steps."""

    def tearDown(self):
        # Drop Python references and cached CUDA blocks between the large-model tests.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        """Return pipeline call kwargs for the benchmark image/mask pair.

        Args:
            device: accepted for signature compatibility; not used by this helper.
            generator_device: device on which the seeded ``torch.Generator`` is created.
            dtype: accepted for signature compatibility; not used by this helper.
            seed: seed for the generator.
        """
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_mask.png"
        )
        return {
            "prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 3,
            "guidance_scale": 7.5,
            "output_type": "numpy",
        }

    def _load_pipeline(self, scheduler_cls=None, **pretrained_kwargs):
        """Load the checkpoint without a safety checker and prepare it for inference.

        Optionally replaces the scheduler with ``scheduler_cls`` (built from the existing
        scheduler config). Then moves the pipeline to ``torch_device`` and enables
        attention slicing.
        """
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None, **pretrained_kwargs
        )
        if scheduler_cls is not None:
            pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()
        return pipe

    @staticmethod
    def _center_slice(image):
        """Flattened 3x3 patch of the last channel around the centre of a 512x512 batch image."""
        return image[0, 253:256, 253:256, -1].flatten()

    def test_stable_diffusion_inpaint_ddim(self):
        # Uses the checkpoint's own default scheduler (no override).
        pipe = self._load_pipeline()

        image = pipe(**self.get_inputs(torch_device)).images
        image_slice = self._center_slice(image)

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.0427, 0.0460, 0.0483, 0.0460, 0.0584, 0.0521, 0.1549, 0.1695, 0.1794])
        assert np.abs(expected_slice - image_slice).max() < 1e-4

    def test_stable_diffusion_inpaint_fp16(self):
        pipe = self._load_pipeline(torch_dtype=torch.float16)

        image = pipe(**self.get_inputs(torch_device, dtype=torch.float16)).images
        image_slice = self._center_slice(image)

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.1350, 0.1123, 0.1350, 0.1641, 0.1328, 0.1230, 0.1289, 0.1531, 0.1687])
        # fp16 inference is noisier, hence the looser tolerance.
        assert np.abs(expected_slice - image_slice).max() < 5e-2

    def test_stable_diffusion_inpaint_pndm(self):
        pipe = self._load_pipeline(scheduler_cls=PNDMScheduler)

        image = pipe(**self.get_inputs(torch_device)).images
        image_slice = self._center_slice(image)

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])
        assert np.abs(expected_slice - image_slice).max() < 1e-4

    def test_stable_diffusion_inpaint_k_lms(self):
        pipe = self._load_pipeline(scheduler_cls=LMSDiscreteScheduler)

        image = pipe(**self.get_inputs(torch_device)).images
        image_slice = self._center_slice(image)

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.9314, 0.7575, 0.9432, 0.8885, 0.9028, 0.7298, 0.9811, 0.9667, 0.7633])
        assert np.abs(expected_slice - image_slice).max() < 1e-4

    def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
        # reset_peak_memory_stats() also resets the max-allocated counter, so the
        # deprecated reset_max_memory_allocated() call is unnecessary.
        torch.cuda.reset_peak_memory_stats()

        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None, torch_dtype=torch.float16
        )
        # Do NOT move the pipeline to the GPU first. Doing so would allocate the full fp16
        # weights on the device after the peak counter was reset. That allocation would
        # then dominate the measurement below and hide what sequential offloading saves.
        # Device placement is left to enable_sequential_cpu_offload.
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)
        _ = pipe(**inputs)

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 2.2 GB is allocated
        assert mem_bytes < 2.2 * 10**9
276
277
278
@nightly
@require_torch_gpu
class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):
    """Full-image regression tests (50 steps by default) against stored reference arrays."""

    def tearDown(self):
        # Free Python garbage and cached CUDA memory after each full-size pipeline run.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        """Return call kwargs for the benchmark image/mask; `device` and `dtype` are unused here."""
        seeded_generator = torch.Generator(device=generator_device).manual_seed(seed)
        source = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_image.png"
        )
        mask = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_mask.png"
        )
        return dict(
            prompt="Face of a yellow cat, high resolution, sitting on a park bench",
            image=source,
            mask_image=mask,
            generator=seeded_generator,
            num_inference_steps=50,
            guidance_scale=7.5,
            output_type="numpy",
        )

    def _assert_matches_reference(self, reference_file, scheduler_cls=None, num_inference_steps=None):
        """Run the default checkpoint and compare the first output image with a stored array.

        Args:
            reference_file: file name of the ``.npy`` reference under the
                ``stable_diffusion_inpaint`` test-array folder.
            scheduler_cls: optional scheduler class that replaces the pipeline's own,
                built from the current scheduler config.
            num_inference_steps: optional override of the step count from `get_inputs`.
        """
        pipeline = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
        if scheduler_cls is not None:
            pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)
        pipeline.to(torch_device)
        pipeline.set_progress_bar_config(disable=None)

        call_kwargs = self.get_inputs(torch_device)
        if num_inference_steps is not None:
            call_kwargs["num_inference_steps"] = num_inference_steps
        produced = pipeline(**call_kwargs).images[0]

        reference = load_numpy(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/" + reference_file
        )
        worst = np.abs(reference - produced).max()
        assert worst < 1e-3

    def test_inpaint_ddim(self):
        self._assert_matches_reference("stable_diffusion_inpaint_ddim.npy")

    def test_inpaint_pndm(self):
        self._assert_matches_reference("stable_diffusion_inpaint_pndm.npy", scheduler_cls=PNDMScheduler)

    def test_inpaint_lms(self):
        self._assert_matches_reference("stable_diffusion_inpaint_lms.npy", scheduler_cls=LMSDiscreteScheduler)

    def test_inpaint_dpm(self):
        self._assert_matches_reference(
            "stable_diffusion_inpaint_dpm_multi.npy",
            scheduler_cls=DPMSolverMultistepScheduler,
            num_inference_steps=30,
        )
370
371
372
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
    """Tests for `prepare_mask_and_masked_image` across PIL, numpy and torch inputs."""

    def _assert_tensor_matches_numpy(self, im_tensor, mask_tensor, im_np, mask_np):
        """Assert that uint8 tensor inputs (rescaled to [-1, 1]) and numpy inputs give equal outputs."""
        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_pil_inputs(self):
        im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
        im = Image.fromarray(im)
        mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
        mask = Image.fromarray((mask * 255).astype(np.uint8))

        t_mask, t_masked = prepare_mask_and_masked_image(im, mask)

        self.assertIsInstance(t_mask, torch.Tensor)
        self.assertIsInstance(t_masked, torch.Tensor)

        self.assertEqual(t_mask.ndim, 4)
        self.assertEqual(t_masked.ndim, 4)

        self.assertEqual(t_mask.shape, (1, 1, 32, 32))
        self.assertEqual(t_masked.shape, (1, 3, 32, 32))

        self.assertEqual(t_mask.dtype, torch.float32)
        self.assertEqual(t_masked.dtype, torch.float32)

        # Mask must lie in [0, 1] and the masked image in [-1, 1].
        self.assertTrue(t_mask.min() >= 0.0)
        self.assertTrue(t_mask.max() <= 1.0)
        self.assertTrue(t_masked.min() >= -1.0)
        # The upper bound has to be checked on max(); `min() <= 1.0` would be vacuous.
        self.assertTrue(t_masked.max() <= 1.0)

        self.assertTrue(t_mask.sum() > 0.0)

    def test_np_inputs(self):
        im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
        im_pil = Image.fromarray(im_np)
        mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
        mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))

        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
        t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil)

        self.assertTrue((t_mask_np == t_mask_pil).all())
        self.assertTrue((t_masked_np == t_masked_pil).all())

    def test_torch_3D_2D_inputs(self):
        im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy().transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()

        self._assert_tensor_matches_numpy(im_tensor, mask_tensor, im_np, mask_np)

    def test_torch_3D_3D_inputs(self):
        im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy().transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()[0]

        self._assert_tensor_matches_numpy(im_tensor, mask_tensor, im_np, mask_np)

    def test_torch_4D_2D_inputs(self):
        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()

        self._assert_tensor_matches_numpy(im_tensor, mask_tensor, im_np, mask_np)

    def test_torch_4D_3D_inputs(self):
        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()[0]

        self._assert_tensor_matches_numpy(im_tensor, mask_tensor, im_np, mask_np)

    def test_torch_4D_4D_inputs(self):
        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()[0][0]

        self._assert_tensor_matches_numpy(im_tensor, mask_tensor, im_np, mask_np)

    def test_torch_batch_4D_3D(self):
        im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5

        im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
        mask_nps = [mask.numpy() for mask in mask_tensor]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        # Process each sample individually and concatenate, to compare with the batched call.
        nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
        t_mask_np = torch.cat([n[0] for n in nps])
        t_masked_np = torch.cat([n[1] for n in nps])

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_batch_4D_4D(self):
        im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5

        im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
        mask_nps = [mask.numpy()[0] for mask in mask_tensor]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        # Process each sample individually and concatenate, to compare with the batched call.
        nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
        t_mask_np = torch.cat([n[0] for n in nps])
        t_masked_np = torch.cat([n[1] for n in nps])

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_shape_mismatch(self):
        # test height and width
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64))
        # test batch dim (mask without channel dim). H/W match the image, so this case
        # isolates the batch-size mismatch instead of re-testing the spatial one.
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 32, 32))
        # test batch dim (mask with explicit channel dim), again with matching H/W
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 32, 32))

    def test_type_mismatch(self):
        # tensor image with a non-tensor (numpy) mask
        with self.assertRaises(TypeError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy())
        # non-tensor (numpy) image with a tensor mask
        with self.assertRaises(TypeError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32))

    def test_channels_first(self):
        # test channels first for 3D tensors
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32))

    def test_tensor_range(self):
        # test im <= 1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32))
        # test im >= -1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32))
        # test mask <= 1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2)
        # test mask >= 0
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1)
539
540