Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/test_pipelines_common.py
1440 views
1
import contextlib
2
import gc
3
import inspect
4
import io
5
import re
6
import tempfile
7
import unittest
8
from typing import Callable, Union
9
10
import numpy as np
11
import torch
12
13
import diffusers
14
from diffusers import DiffusionPipeline
15
from diffusers.utils import logging
16
from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
17
from diffusers.utils.testing_utils import require_torch, torch_device
18
19
20
torch.backends.cuda.matmul.allow_tf32 = False
21
22
23
def to_np(tensor):
24
if isinstance(tensor, torch.Tensor):
25
tensor = tensor.detach().cpu().numpy()
26
27
return tensor
28
29
30
@require_torch
31
class PipelineTesterMixin:
32
"""
33
This mixin is designed to be used with unittest.TestCase classes.
34
It provides a set of common tests for each PyTorch pipeline, e.g. saving and loading the pipeline,
35
equivalence of dict and tuple outputs, etc.
36
"""
37
38
# Canonical parameters that are passed to `__call__` regardless
39
# of the type of pipeline. They are always optional and have common
40
# sense default values.
41
required_optional_params = frozenset(
42
[
43
"num_inference_steps",
44
"num_images_per_prompt",
45
"generator",
46
"latents",
47
"output_type",
48
"return_dict",
49
"callback",
50
"callback_steps",
51
]
52
)
53
54
# set these parameters to False in the child class if the pipeline does not support the corresponding functionality
55
test_attention_slicing = True
56
test_cpu_offload = True
57
test_xformers_attention = True
58
59
def get_generator(self, seed):
60
device = torch_device if torch_device != "mps" else "cpu"
61
generator = torch.Generator(device).manual_seed(seed)
62
return generator
63
64
@property
65
def pipeline_class(self) -> Union[Callable, DiffusionPipeline]:
66
raise NotImplementedError(
67
"You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
68
"See existing pipeline tests for reference."
69
)
70
71
def get_dummy_components(self):
72
raise NotImplementedError(
73
"You need to implement `get_dummy_components(self)` in the child test class. "
74
"See existing pipeline tests for reference."
75
)
76
77
def get_dummy_inputs(self, device, seed=0):
78
raise NotImplementedError(
79
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
80
"See existing pipeline tests for reference."
81
)
82
83
@property
84
def params(self) -> frozenset:
85
raise NotImplementedError(
86
"You need to set the attribute `params` in the child test class. "
87
"`params` are checked for if all values are present in `__call__`'s signature."
88
" You can set `params` using one of the common set of parameters defined in`pipeline_params.py`"
89
" e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
90
"image pipelines, including prompts and prompt embedding overrides."
91
"If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
92
"do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
93
"with non-configurable height and width arguments should set the attribute as "
94
"`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
95
"See existing pipeline tests for reference."
96
)
97
98
@property
99
def batch_params(self) -> frozenset:
100
raise NotImplementedError(
101
"You need to set the attribute `batch_params` in the child test class. "
102
"`batch_params` are the parameters required to be batched when passed to the pipeline's "
103
"`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
104
"`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
105
"set of batch arguments has minor changes from one of the common sets of batch arguments, "
106
"do not make modifications to the existing common sets of batch arguments. I.e. a text to "
107
"image pipeline `negative_prompt` is not batched should set the attribute as "
108
"`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
109
"See existing pipeline tests for reference."
110
)
111
112
def tearDown(self):
113
# clean up the VRAM after each test in case of CUDA runtime errors
114
super().tearDown()
115
gc.collect()
116
torch.cuda.empty_cache()
117
118
def test_save_load_local(self):
119
components = self.get_dummy_components()
120
pipe = self.pipeline_class(**components)
121
pipe.to(torch_device)
122
pipe.set_progress_bar_config(disable=None)
123
124
inputs = self.get_dummy_inputs(torch_device)
125
output = pipe(**inputs)[0]
126
127
with tempfile.TemporaryDirectory() as tmpdir:
128
pipe.save_pretrained(tmpdir)
129
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
130
pipe_loaded.to(torch_device)
131
pipe_loaded.set_progress_bar_config(disable=None)
132
133
inputs = self.get_dummy_inputs(torch_device)
134
output_loaded = pipe_loaded(**inputs)[0]
135
136
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
137
self.assertLess(max_diff, 1e-4)
138
139
def test_pipeline_call_signature(self):
140
self.assertTrue(
141
hasattr(self.pipeline_class, "__call__"), f"{self.pipeline_class} should have a `__call__` method"
142
)
143
144
parameters = inspect.signature(self.pipeline_class.__call__).parameters
145
146
optional_parameters = set()
147
148
for k, v in parameters.items():
149
if v.default != inspect._empty:
150
optional_parameters.add(k)
151
152
parameters = set(parameters.keys())
153
parameters.remove("self")
154
parameters.discard("kwargs") # kwargs can be added if arguments of pipeline call function are deprecated
155
156
remaining_required_parameters = set()
157
158
for param in self.params:
159
if param not in parameters:
160
remaining_required_parameters.add(param)
161
162
self.assertTrue(
163
len(remaining_required_parameters) == 0,
164
f"Required parameters not present: {remaining_required_parameters}",
165
)
166
167
remaining_required_optional_parameters = set()
168
169
for param in self.required_optional_params:
170
if param not in optional_parameters:
171
remaining_required_optional_parameters.add(param)
172
173
self.assertTrue(
174
len(remaining_required_optional_parameters) == 0,
175
f"Required optional parameters not present: {remaining_required_optional_parameters}",
176
)
177
178
def test_inference_batch_consistent(self):
179
self._test_inference_batch_consistent()
180
181
def _test_inference_batch_consistent(
182
self, batch_sizes=[2, 4, 13], additional_params_copy_to_batched_inputs=["num_inference_steps"]
183
):
184
components = self.get_dummy_components()
185
pipe = self.pipeline_class(**components)
186
pipe.to(torch_device)
187
pipe.set_progress_bar_config(disable=None)
188
189
inputs = self.get_dummy_inputs(torch_device)
190
191
logger = logging.get_logger(pipe.__module__)
192
logger.setLevel(level=diffusers.logging.FATAL)
193
194
# batchify inputs
195
for batch_size in batch_sizes:
196
batched_inputs = {}
197
for name, value in inputs.items():
198
if name in self.batch_params:
199
# prompt is string
200
if name == "prompt":
201
len_prompt = len(value)
202
# make unequal batch sizes
203
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
204
205
# make last batch super long
206
batched_inputs[name][-1] = 2000 * "very long"
207
# or else we have images
208
else:
209
batched_inputs[name] = batch_size * [value]
210
elif name == "batch_size":
211
batched_inputs[name] = batch_size
212
else:
213
batched_inputs[name] = value
214
215
for arg in additional_params_copy_to_batched_inputs:
216
batched_inputs[arg] = inputs[arg]
217
218
batched_inputs["output_type"] = None
219
220
if self.pipeline_class.__name__ == "DanceDiffusionPipeline":
221
batched_inputs.pop("output_type")
222
223
output = pipe(**batched_inputs)
224
225
assert len(output[0]) == batch_size
226
227
batched_inputs["output_type"] = "np"
228
229
if self.pipeline_class.__name__ == "DanceDiffusionPipeline":
230
batched_inputs.pop("output_type")
231
232
output = pipe(**batched_inputs)[0]
233
234
assert output.shape[0] == batch_size
235
236
logger.setLevel(level=diffusers.logging.WARNING)
237
238
def test_inference_batch_single_identical(self):
239
self._test_inference_batch_single_identical()
240
241
def _test_inference_batch_single_identical(
242
self,
243
test_max_difference=None,
244
test_mean_pixel_difference=None,
245
relax_max_difference=False,
246
expected_max_diff=1e-4,
247
additional_params_copy_to_batched_inputs=["num_inference_steps"],
248
):
249
if test_max_difference is None:
250
# TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems
251
# make sure that batched and non-batched is identical
252
test_max_difference = torch_device != "mps"
253
254
if test_mean_pixel_difference is None:
255
# TODO same as above
256
test_mean_pixel_difference = torch_device != "mps"
257
258
components = self.get_dummy_components()
259
pipe = self.pipeline_class(**components)
260
pipe.to(torch_device)
261
pipe.set_progress_bar_config(disable=None)
262
263
inputs = self.get_dummy_inputs(torch_device)
264
265
logger = logging.get_logger(pipe.__module__)
266
logger.setLevel(level=diffusers.logging.FATAL)
267
268
# batchify inputs
269
batched_inputs = {}
270
batch_size = 3
271
for name, value in inputs.items():
272
if name in self.batch_params:
273
# prompt is string
274
if name == "prompt":
275
len_prompt = len(value)
276
# make unequal batch sizes
277
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
278
279
# make last batch super long
280
batched_inputs[name][-1] = 2000 * "very long"
281
# or else we have images
282
else:
283
batched_inputs[name] = batch_size * [value]
284
elif name == "batch_size":
285
batched_inputs[name] = batch_size
286
elif name == "generator":
287
batched_inputs[name] = [self.get_generator(i) for i in range(batch_size)]
288
else:
289
batched_inputs[name] = value
290
291
for arg in additional_params_copy_to_batched_inputs:
292
batched_inputs[arg] = inputs[arg]
293
294
if self.pipeline_class.__name__ != "DanceDiffusionPipeline":
295
batched_inputs["output_type"] = "np"
296
297
output_batch = pipe(**batched_inputs)
298
assert output_batch[0].shape[0] == batch_size
299
300
inputs["generator"] = self.get_generator(0)
301
302
output = pipe(**inputs)
303
304
logger.setLevel(level=diffusers.logging.WARNING)
305
if test_max_difference:
306
if relax_max_difference:
307
# Taking the median of the largest <n> differences
308
# is resilient to outliers
309
diff = np.abs(output_batch[0][0] - output[0][0])
310
diff = diff.flatten()
311
diff.sort()
312
max_diff = np.median(diff[-5:])
313
else:
314
max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
315
assert max_diff < expected_max_diff
316
317
if test_mean_pixel_difference:
318
assert_mean_pixel_difference(output_batch[0][0], output[0][0])
319
320
def test_dict_tuple_outputs_equivalent(self):
321
components = self.get_dummy_components()
322
pipe = self.pipeline_class(**components)
323
pipe.to(torch_device)
324
pipe.set_progress_bar_config(disable=None)
325
326
output = pipe(**self.get_dummy_inputs(torch_device))[0]
327
output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0]
328
329
max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
330
self.assertLess(max_diff, 1e-4)
331
332
def test_components_function(self):
333
init_components = self.get_dummy_components()
334
pipe = self.pipeline_class(**init_components)
335
336
self.assertTrue(hasattr(pipe, "components"))
337
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
338
339
@unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
340
def test_float16_inference(self):
341
components = self.get_dummy_components()
342
pipe = self.pipeline_class(**components)
343
pipe.to(torch_device)
344
pipe.set_progress_bar_config(disable=None)
345
346
pipe_fp16 = self.pipeline_class(**components)
347
pipe_fp16.to(torch_device, torch.float16)
348
pipe_fp16.set_progress_bar_config(disable=None)
349
350
output = pipe(**self.get_dummy_inputs(torch_device))[0]
351
output_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0]
352
353
max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
354
self.assertLess(max_diff, 1e-2, "The outputs of the fp16 and fp32 pipelines are too different.")
355
356
@unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
357
def test_save_load_float16(self):
358
components = self.get_dummy_components()
359
for name, module in components.items():
360
if hasattr(module, "half"):
361
components[name] = module.to(torch_device).half()
362
pipe = self.pipeline_class(**components)
363
pipe.to(torch_device)
364
pipe.set_progress_bar_config(disable=None)
365
366
inputs = self.get_dummy_inputs(torch_device)
367
output = pipe(**inputs)[0]
368
369
with tempfile.TemporaryDirectory() as tmpdir:
370
pipe.save_pretrained(tmpdir)
371
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
372
pipe_loaded.to(torch_device)
373
pipe_loaded.set_progress_bar_config(disable=None)
374
375
for name, component in pipe_loaded.components.items():
376
if hasattr(component, "dtype"):
377
self.assertTrue(
378
component.dtype == torch.float16,
379
f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
380
)
381
382
inputs = self.get_dummy_inputs(torch_device)
383
output_loaded = pipe_loaded(**inputs)[0]
384
385
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
386
self.assertLess(max_diff, 1e-2, "The output of the fp16 pipeline changed after saving and loading.")
387
388
def test_save_load_optional_components(self):
389
if not hasattr(self.pipeline_class, "_optional_components"):
390
return
391
392
components = self.get_dummy_components()
393
pipe = self.pipeline_class(**components)
394
pipe.to(torch_device)
395
pipe.set_progress_bar_config(disable=None)
396
397
# set all optional components to None
398
for optional_component in pipe._optional_components:
399
setattr(pipe, optional_component, None)
400
401
inputs = self.get_dummy_inputs(torch_device)
402
output = pipe(**inputs)[0]
403
404
with tempfile.TemporaryDirectory() as tmpdir:
405
pipe.save_pretrained(tmpdir)
406
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
407
pipe_loaded.to(torch_device)
408
pipe_loaded.set_progress_bar_config(disable=None)
409
410
for optional_component in pipe._optional_components:
411
self.assertTrue(
412
getattr(pipe_loaded, optional_component) is None,
413
f"`{optional_component}` did not stay set to None after loading.",
414
)
415
416
inputs = self.get_dummy_inputs(torch_device)
417
output_loaded = pipe_loaded(**inputs)[0]
418
419
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
420
self.assertLess(max_diff, 1e-4)
421
422
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
423
def test_to_device(self):
424
components = self.get_dummy_components()
425
pipe = self.pipeline_class(**components)
426
pipe.set_progress_bar_config(disable=None)
427
428
pipe.to("cpu")
429
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
430
self.assertTrue(all(device == "cpu" for device in model_devices))
431
432
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
433
self.assertTrue(np.isnan(output_cpu).sum() == 0)
434
435
pipe.to("cuda")
436
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
437
self.assertTrue(all(device == "cuda" for device in model_devices))
438
439
output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
440
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
441
442
def test_to_dtype(self):
443
components = self.get_dummy_components()
444
pipe = self.pipeline_class(**components)
445
pipe.set_progress_bar_config(disable=None)
446
447
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
448
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
449
450
pipe.to(torch_dtype=torch.float16)
451
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
452
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
453
454
def test_attention_slicing_forward_pass(self):
455
self._test_attention_slicing_forward_pass()
456
457
def _test_attention_slicing_forward_pass(
458
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
459
):
460
if not self.test_attention_slicing:
461
return
462
463
components = self.get_dummy_components()
464
pipe = self.pipeline_class(**components)
465
pipe.to(torch_device)
466
pipe.set_progress_bar_config(disable=None)
467
468
inputs = self.get_dummy_inputs(torch_device)
469
output_without_slicing = pipe(**inputs)[0]
470
471
pipe.enable_attention_slicing(slice_size=1)
472
inputs = self.get_dummy_inputs(torch_device)
473
output_with_slicing = pipe(**inputs)[0]
474
475
if test_max_difference:
476
max_diff = np.abs(to_np(output_with_slicing) - to_np(output_without_slicing)).max()
477
self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
478
479
if test_mean_pixel_difference:
480
assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
481
482
@unittest.skipIf(
483
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
484
reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
485
)
486
def test_cpu_offload_forward_pass(self):
487
if not self.test_cpu_offload:
488
return
489
490
components = self.get_dummy_components()
491
pipe = self.pipeline_class(**components)
492
pipe.to(torch_device)
493
pipe.set_progress_bar_config(disable=None)
494
495
inputs = self.get_dummy_inputs(torch_device)
496
output_without_offload = pipe(**inputs)[0]
497
498
pipe.enable_sequential_cpu_offload()
499
inputs = self.get_dummy_inputs(torch_device)
500
output_with_offload = pipe(**inputs)[0]
501
502
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
503
self.assertLess(max_diff, 1e-4, "CPU offloading should not affect the inference results")
504
505
@unittest.skipIf(
506
torch_device != "cuda" or not is_xformers_available(),
507
reason="XFormers attention is only available with CUDA and `xformers` installed",
508
)
509
def test_xformers_attention_forwardGenerator_pass(self):
510
self._test_xformers_attention_forwardGenerator_pass()
511
512
def _test_xformers_attention_forwardGenerator_pass(self, test_max_difference=True, expected_max_diff=1e-4):
513
if not self.test_xformers_attention:
514
return
515
516
components = self.get_dummy_components()
517
pipe = self.pipeline_class(**components)
518
pipe.to(torch_device)
519
pipe.set_progress_bar_config(disable=None)
520
521
inputs = self.get_dummy_inputs(torch_device)
522
output_without_offload = pipe(**inputs)[0]
523
524
pipe.enable_xformers_memory_efficient_attention()
525
inputs = self.get_dummy_inputs(torch_device)
526
output_with_offload = pipe(**inputs)[0]
527
528
if test_max_difference:
529
max_diff = np.abs(output_with_offload - output_without_offload).max()
530
self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results")
531
532
assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])
533
534
def test_progress_bar(self):
535
components = self.get_dummy_components()
536
pipe = self.pipeline_class(**components)
537
pipe.to(torch_device)
538
539
inputs = self.get_dummy_inputs(torch_device)
540
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
541
_ = pipe(**inputs)
542
stderr = stderr.getvalue()
543
# we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
544
# so we just match "5" in "#####| 1/5 [00:01<00:00]"
545
max_steps = re.search("/(.*?) ", stderr).group(1)
546
self.assertTrue(max_steps is not None and len(max_steps) > 0)
547
self.assertTrue(
548
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
549
)
550
551
pipe.set_progress_bar_config(disable=True)
552
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
553
_ = pipe(**inputs)
554
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
555
556
def test_num_images_per_prompt(self):
557
sig = inspect.signature(self.pipeline_class.__call__)
558
559
if "num_images_per_prompt" not in sig.parameters:
560
return
561
562
components = self.get_dummy_components()
563
pipe = self.pipeline_class(**components)
564
pipe = pipe.to(torch_device)
565
pipe.set_progress_bar_config(disable=None)
566
567
batch_sizes = [1, 2]
568
num_images_per_prompts = [1, 2]
569
570
for batch_size in batch_sizes:
571
for num_images_per_prompt in num_images_per_prompts:
572
inputs = self.get_dummy_inputs(torch_device)
573
574
for key in inputs.keys():
575
if key in self.batch_params:
576
inputs[key] = batch_size * [inputs[key]]
577
578
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
579
580
assert images.shape[0] == batch_size * num_images_per_prompt
581
582
583
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
584
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
585
# reference image.
586
def assert_mean_pixel_difference(image, expected_image):
587
image = np.asarray(DiffusionPipeline.numpy_to_pil(image)[0], dtype=np.float32)
588
expected_image = np.asarray(DiffusionPipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
589
avg_diff = np.abs(image - expected_image).mean()
590
assert avg_diff < 10, f"Error image deviates {avg_diff} pixels on average"
591
592