GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/test_pipelines.py

# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import json
import os
import random
import shutil
import sys
import tempfile
import unittest
import unittest.mock as mock

import numpy as np
import PIL
import requests_mock
import safetensors.torch
import torch
from parameterized import parameterized
from PIL import Image
from requests.exceptions import HTTPError
from transformers import CLIPImageProcessor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMPipeline,
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipelineLegacy,
    StableDiffusionPipeline,
    UNet2DConditionModel,
    UNet2DModel,
    UniPCMultistepScheduler,
    logging,
)
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu


torch.backends.cuda.matmul.allow_tf32 = False


class DownloadTests(unittest.TestCase):
    def test_one_request_upon_cached(self):
        # TODO: For some reason this test fails on MPS where no HEAD call is made.
        if torch_device == "mps":
            return

        with tempfile.TemporaryDirectory() as tmpdirname:
            with requests_mock.mock(real_http=True) as m:
                DiffusionPipeline.download(
                    "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
                )

            download_requests = [r.method for r in m.request_history]
            assert download_requests.count("HEAD") == 15, "15 calls to files"
            assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json"
            assert (
                len(download_requests) == 32
            ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"

            with requests_mock.mock(real_http=True) as m:
                DiffusionPipeline.download(
                    "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
                )

            cache_requests = [r.method for r in m.request_history]
            assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
            assert cache_requests.count("GET") == 1, "model info is only GET"
            assert (
                len(cache_requests) == 2
            ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"

    def test_download_only_pytorch(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            # pipeline has Flax weights
            tmpdirname = DiffusionPipeline.download(
                "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
            )

            all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
            files = [item for sublist in all_root_files for item in sublist]

            # None of the downloaded files should be a flax file even if we have some here:
            # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
            assert not any(f.endswith(".msgpack") for f in files)
            # We need to never convert this tiny model to safetensors for this test to pass
            assert not any(f.endswith(".safetensors") for f in files)

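    # The `os.walk`-flattening idiom above recurs throughout this file. A hedged
    # convenience version of it (our addition for illustration; the tests below
    # keep inlining the list comprehensions rather than calling this helper):
    @staticmethod
    def _flatten_walk(root):
        # `os.walk` yields (dirpath, dirnames, filenames); keep only the filenames.
        all_root_files = [t[-1] for t in os.walk(root)]
        return [item for sublist in all_root_files for item in sublist]
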
    def test_force_safetensors_error(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            # pipeline has no safetensors weights
            with self.assertRaises(EnvironmentError):
                tmpdirname = DiffusionPipeline.download(
                    "hf-internal-testing/tiny-stable-diffusion-pipe-no-safetensors",
                    safety_checker=None,
                    cache_dir=tmpdirname,
                    use_safetensors=True,
                )

    def test_returned_cached_folder(self):
        prompt = "hello"
        pipe = StableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
        )
        _, local_path = StableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None, return_cached_folder=True
        )
        pipe_2 = StableDiffusionPipeline.from_pretrained(local_path)

        pipe = pipe.to(torch_device)
        pipe_2 = pipe_2.to(torch_device)

        generator = torch.manual_seed(0)
        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        generator = torch.manual_seed(0)
        out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        assert np.max(np.abs(out - out_2)) < 1e-3

    def test_download_safetensors(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            # pipeline has safetensors weights
            tmpdirname = DiffusionPipeline.download(
                "hf-internal-testing/tiny-stable-diffusion-pipe-safetensors",
                safety_checker=None,
                cache_dir=tmpdirname,
            )

            all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
            files = [item for sublist in all_root_files for item in sublist]

            # None of the downloaded files should be a pytorch file even if we have some here:
            # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
            assert not any(f.endswith(".bin") for f in files)

    def test_download_no_safety_checker(self):
        prompt = "hello"
        pipe = StableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
        )
        pipe = pipe.to(torch_device)
        generator = torch.manual_seed(0)
        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
        pipe_2 = pipe_2.to(torch_device)
        generator = torch.manual_seed(0)
        out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        assert np.max(np.abs(out - out_2)) < 1e-3

    def test_load_no_safety_checker_explicit_locally(self):
        prompt = "hello"
        pipe = StableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
        )
        pipe = pipe.to(torch_device)
        generator = torch.manual_seed(0)
        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe.save_pretrained(tmpdirname)
            pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
            pipe_2 = pipe_2.to(torch_device)

            generator = torch.manual_seed(0)

            out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        assert np.max(np.abs(out - out_2)) < 1e-3

    def test_load_no_safety_checker_default_locally(self):
        prompt = "hello"
        pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
        pipe = pipe.to(torch_device)

        generator = torch.manual_seed(0)
        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe.save_pretrained(tmpdirname)
            pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
            pipe_2 = pipe_2.to(torch_device)

            generator = torch.manual_seed(0)

            out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

        assert np.max(np.abs(out - out_2)) < 1e-3

    def test_cached_files_are_used_when_no_internet(self):
        # A mock response for an HTTP HEAD request to emulate server down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        orig_pipe = StableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
        )
        orig_comps = {k: v for k, v in orig_pipe.components.items() if hasattr(v, "parameters")}

        # Under the mock environment we get a 500 error when trying to reach the model.
        with mock.patch("requests.request", return_value=response_mock):
            # Loading with `local_files_only=True` must fall back to the cached weights.
            pipe = StableDiffusionPipeline.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None, local_files_only=True
            )
            comps = {k: v for k, v in pipe.components.items() if hasattr(v, "parameters")}

        for m1, m2 in zip(orig_comps.values(), comps.values()):
            for p1, p2 in zip(m1.parameters(), m2.parameters()):
                if p1.data.ne(p2.data).sum() > 0:
                    assert False, "Parameters not the same!"

    def test_download_from_variant_folder(self):
        for safe_avail in [False, True]:
            import diffusers

            diffusers.utils.import_utils._safetensors_available = safe_avail

            other_format = ".bin" if safe_avail else ".safetensors"
            with tempfile.TemporaryDirectory() as tmpdirname:
                tmpdirname = StableDiffusionPipeline.download(
                    "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname
                )
                all_root_files = [t[-1] for t in os.walk(tmpdirname)]
                files = [item for sublist in all_root_files for item in sublist]

                # None of the downloaded files should be a variant file even if we have some here:
                # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
                assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
                assert not any(f.endswith(other_format) for f in files)
                # no variants
                assert not any(len(f.split(".")) == 3 for f in files)

        diffusers.utils.import_utils._safetensors_available = True

    def test_download_variant_all(self):
        for safe_avail in [False, True]:
            import diffusers

            diffusers.utils.import_utils._safetensors_available = safe_avail

            other_format = ".bin" if safe_avail else ".safetensors"
            this_format = ".safetensors" if safe_avail else ".bin"
            variant = "fp16"

            with tempfile.TemporaryDirectory() as tmpdirname:
                tmpdirname = StableDiffusionPipeline.download(
                    "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
                )
                all_root_files = [t[-1] for t in os.walk(tmpdirname)]
                files = [item for sublist in all_root_files for item in sublist]

                # None of the downloaded files should be a non-variant file even if we have some here:
                # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
                assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
                # unet, vae, text_encoder, safety_checker
                assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 4
                # all checkpoints should have variant ending
                assert not any(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files)
                assert not any(f.endswith(other_format) for f in files)

        diffusers.utils.import_utils._safetensors_available = True

    def test_download_variant_partly(self):
        for safe_avail in [False, True]:
            import diffusers

            diffusers.utils.import_utils._safetensors_available = safe_avail

            other_format = ".bin" if safe_avail else ".safetensors"
            this_format = ".safetensors" if safe_avail else ".bin"
            variant = "no_ema"

            with tempfile.TemporaryDirectory() as tmpdirname:
                tmpdirname = StableDiffusionPipeline.download(
                    "hf-internal-testing/stable-diffusion-all-variants", cache_dir=tmpdirname, variant=variant
                )
                all_root_files = [t[-1] for t in os.walk(tmpdirname)]
                files = [item for sublist in all_root_files for item in sublist]

                unet_files = os.listdir(os.path.join(tmpdirname, "unet"))

                # Some of the downloaded files should be non-variant files, check:
                # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
                assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
                # only the unet has a "no_ema" variant
                assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
                assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
                # vae, safety_checker and text_encoder should have no variant
                assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
                assert not any(f.endswith(other_format) for f in files)

        diffusers.utils.import_utils._safetensors_available = True

    def test_download_broken_variant(self):
        for safe_avail in [False, True]:
            import diffusers

            diffusers.utils.import_utils._safetensors_available = safe_avail
            # The text encoder is missing both its non-variant and its "no_ema" variant
            # weights, so the following loads can't work.
            for variant in [None, "no_ema"]:
                with self.assertRaises(OSError) as error_context:
                    with tempfile.TemporaryDirectory() as tmpdirname:
                        tmpdirname = StableDiffusionPipeline.from_pretrained(
                            "hf-internal-testing/stable-diffusion-broken-variants",
                            cache_dir=tmpdirname,
                            variant=variant,
                        )

                assert "Error no file name" in str(error_context.exception)

            # The text encoder does ship "fp16" variant weights, so this load succeeds.
            with tempfile.TemporaryDirectory() as tmpdirname:
                tmpdirname = StableDiffusionPipeline.download(
                    "hf-internal-testing/stable-diffusion-broken-variants", cache_dir=tmpdirname, variant="fp16"
                )

                all_root_files = [t[-1] for t in os.walk(tmpdirname)]
                files = [item for sublist in all_root_files for item in sublist]

                # None of the downloaded files should be a non-variant file even if we have some here:
                # https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
                assert len(files) == 15, f"We should only download 15 files, not {len(files)}"

        diffusers.utils.import_utils._safetensors_available = True


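# A minimal usage sketch of the variant-download behaviour exercised above; the
# wrapper function is our illustration, not diffusers API (the repo id is the one
# used by the tests). With `variant="fp16"`, only `*.fp16.*` checkpoint files are
# fetched for the components that ship that variant.
def _download_fp16_variant_example(cache_dir):
    return StableDiffusionPipeline.download(
        "hf-internal-testing/stable-diffusion-all-variants", cache_dir=cache_dir, variant="fp16"
    )

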
class CustomPipelineTests(unittest.TestCase):
    def test_load_custom_pipeline(self):
        pipeline = DiffusionPipeline.from_pretrained(
            "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
        )
        pipeline = pipeline.to(torch_device)
        # NOTE that `"CustomPipeline"` is not a class that is defined in this library, but solely on the Hub
        # under https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L24
        assert pipeline.__class__.__name__ == "CustomPipeline"

    def test_load_custom_github(self):
        pipeline = DiffusionPipeline.from_pretrained(
            "google/ddpm-cifar10-32", custom_pipeline="one_step_unet", custom_revision="main"
        )

        # make sure that on "main" the pipeline gives only ones because of: https://github.com/huggingface/diffusers/pull/1690
        with torch.no_grad():
            output = pipeline()

        assert output.numel() == output.sum()

        # hack since Python doesn't like overwriting modules: https://stackoverflow.com/questions/3105801/unload-a-module-in-python
        # Could in the future work with hashes instead.
        del sys.modules["diffusers_modules.git.one_step_unet"]

        pipeline = DiffusionPipeline.from_pretrained(
            "google/ddpm-cifar10-32", custom_pipeline="one_step_unet", custom_revision="0.10.2"
        )
        with torch.no_grad():
            output = pipeline()

        assert output.numel() != output.sum()

        assert pipeline.__class__.__name__ == "UnetSchedulerOneForwardPipeline"

    def test_run_custom_pipeline(self):
        pipeline = DiffusionPipeline.from_pretrained(
            "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
        )
        pipeline = pipeline.to(torch_device)
        images, output_str = pipeline(num_inference_steps=2, output_type="np")

        assert images[0].shape == (1, 32, 32, 3)

        # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
        assert output_str == "This is a test"

    def test_local_custom_pipeline_repo(self):
        local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
        pipeline = DiffusionPipeline.from_pretrained(
            "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path
        )
        pipeline = pipeline.to(torch_device)
        images, output_str = pipeline(num_inference_steps=2, output_type="np")

        assert pipeline.__class__.__name__ == "CustomLocalPipeline"
        assert images[0].shape == (1, 32, 32, 3)
        # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
        assert output_str == "This is a local test"

    def test_local_custom_pipeline_file(self):
        local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
        local_custom_pipeline_path = os.path.join(local_custom_pipeline_path, "what_ever.py")
        pipeline = DiffusionPipeline.from_pretrained(
            "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path
        )
        pipeline = pipeline.to(torch_device)
        images, output_str = pipeline(num_inference_steps=2, output_type="np")

        assert pipeline.__class__.__name__ == "CustomLocalPipeline"
        assert images[0].shape == (1, 32, 32, 3)
        # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
        assert output_str == "This is a local test"

    @slow
    @require_torch_gpu
    def test_download_from_git(self):
        clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"

        feature_extractor = CLIPImageProcessor.from_pretrained(clip_model_id)
        clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)

        pipeline = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            custom_pipeline="clip_guided_stable_diffusion",
            clip_model=clip_model,
            feature_extractor=feature_extractor,
            torch_dtype=torch.float16,
        )
        pipeline.enable_attention_slicing()
        pipeline = pipeline.to(torch_device)

        # NOTE that `"CLIPGuidedStableDiffusion"` is not a class that is defined in the PyPI package of the
        # library, but solely in the community examples folder of GitHub under:
        # https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py
        assert pipeline.__class__.__name__ == "CLIPGuidedStableDiffusion"

        image = pipeline("a prompt", num_inference_steps=2, output_type="np").images[0]
        assert image.shape == (512, 512, 3)


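# Sketch of the two resolution paths for `custom_pipeline` covered above: a Hub
# repo that ships a `pipeline.py`, and a file from the community examples folder
# on GitHub pinned to a revision. The repo ids are the ones used in the tests;
# the wrapper function itself is our illustration only.
def _load_custom_pipeline_examples():
    hub_pipe = DiffusionPipeline.from_pretrained(
        "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
    )
    git_pipe = DiffusionPipeline.from_pretrained(
        "google/ddpm-cifar10-32", custom_pipeline="one_step_unet", custom_revision="main"
    )
    return hub_pipe, git_pipe

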
class PipelineFastTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

        import diffusers

        diffusers.utils.import_utils._safetensors_available = True

    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
        return image

    def dummy_uncond_unet(self, sample_size=32):
        torch.manual_seed(0)
        model = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=sample_size,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        return model

    def dummy_cond_unet(self, sample_size=32):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=sample_size,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        return model

    @property
    def dummy_vae(self):
        torch.manual_seed(0)
        model = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        return model

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModel(config)

    @property
    def dummy_extractor(self):
        def extract(*args, **kwargs):
            class Out:
                def __init__(self):
                    self.pixel_values = torch.ones([0])

                def to(self, device):
                    # `torch.Tensor.to` is not in-place, so reassign the result.
                    self.pixel_values = self.pixel_values.to(device)
                    return self

            return Out()

        return extract

    @parameterized.expand(
        [
            [DDIMScheduler, DDIMPipeline, 32],
            [DDPMScheduler, DDPMPipeline, 32],
            [DDIMScheduler, DDIMPipeline, (32, 64)],
            [DDPMScheduler, DDPMPipeline, (64, 32)],
        ]
    )
    def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32):
        unet = self.dummy_uncond_unet(sample_size)
        scheduler = scheduler_fn()
        pipeline = pipeline_fn(unet, scheduler).to(torch_device)

        generator = torch.manual_seed(0)
        out_image = pipeline(
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        ).images
        sample_size = (sample_size, sample_size) if isinstance(sample_size, int) else sample_size
        assert out_image.shape == (1, *sample_size, 3)

    def test_stable_diffusion_components(self):
        """Test that components property works correctly"""
        unet = self.dummy_cond_unet()
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))

        # make sure here that pndm scheduler skips prk
        inpaint = StableDiffusionInpaintPipelineLegacy(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        ).to(torch_device)
        img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
        text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)

        prompt = "A painting of a squirrel eating a burger"

        generator = torch.manual_seed(0)
        image_inpaint = inpaint(
            [prompt],
            generator=generator,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
            mask_image=mask_image,
        ).images
        image_img2img = img2img(
            [prompt],
            generator=generator,
            num_inference_steps=2,
            output_type="np",
            image=init_image,
        ).images
        image_text2img = text2img(
            [prompt],
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        ).images

        assert image_inpaint.shape == (1, 32, 32, 3)
        assert image_img2img.shape == (1, 32, 32, 3)
        assert image_text2img.shape == (1, 64, 64, 3)

    @require_torch_gpu
    def test_pipe_false_offload_warn(self):
        unet = self.dummy_cond_unet()
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        sd = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )

        sd.enable_model_cpu_offload()

        logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
        with CaptureLogger(logger) as cap_logger:
            sd.to("cuda")

        assert "It is strongly recommended against doing so" in str(cap_logger)

        sd = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )

    def test_set_scheduler(self):
        unet = self.dummy_cond_unet()
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        sd = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )

        sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, DDIMScheduler)
        sd.scheduler = DDPMScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, DDPMScheduler)
        sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, PNDMScheduler)
        sd.scheduler = LMSDiscreteScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, LMSDiscreteScheduler)
        sd.scheduler = EulerDiscreteScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, EulerDiscreteScheduler)
        sd.scheduler = EulerAncestralDiscreteScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, EulerAncestralDiscreteScheduler)
        sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
        assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)

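    # Hedged helper (our illustration, not diffusers API): the hot-swap pattern that
    # `test_set_scheduler` exercises class by class, factored into one line. Any
    # compatible scheduler can be rebuilt from the config of the one currently set.
    @staticmethod
    def _swap_scheduler(pipe, scheduler_cls):
        pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
        return pipe
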
    def test_set_scheduler_consistency(self):
        unet = self.dummy_cond_unet()
        pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
        ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        sd = StableDiffusionPipeline(
            unet=unet,
            scheduler=pndm,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )

        pndm_config = sd.scheduler.config
        sd.scheduler = DDPMScheduler.from_config(pndm_config)
        sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
        pndm_config_2 = sd.scheduler.config
        pndm_config_2 = {k: v for k, v in pndm_config_2.items() if k in pndm_config}

        assert dict(pndm_config) == dict(pndm_config_2)

        sd = StableDiffusionPipeline(
            unet=unet,
            scheduler=ddim,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )

        ddim_config = sd.scheduler.config
        sd.scheduler = LMSDiscreteScheduler.from_config(ddim_config)
        sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
        ddim_config_2 = sd.scheduler.config
        ddim_config_2 = {k: v for k, v in ddim_config_2.items() if k in ddim_config}

        assert dict(ddim_config) == dict(ddim_config_2)

    def test_save_safe_serialization(self):
        pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
        with tempfile.TemporaryDirectory() as tmpdirname:
            pipeline.save_pretrained(tmpdirname, safe_serialization=True)

            # Validate that the VAE safetensors file exists and is of the correct format
            vae_path = os.path.join(tmpdirname, "vae", "diffusion_pytorch_model.safetensors")
            assert os.path.exists(vae_path), f"Could not find {vae_path}"
            _ = safetensors.torch.load_file(vae_path)

            # Validate that the UNet safetensors file exists and is of the correct format
            unet_path = os.path.join(tmpdirname, "unet", "diffusion_pytorch_model.safetensors")
            assert os.path.exists(unet_path), f"Could not find {unet_path}"
            _ = safetensors.torch.load_file(unet_path)

            # Validate that the text encoder safetensors file exists and is of the correct format
            text_encoder_path = os.path.join(tmpdirname, "text_encoder", "model.safetensors")
            assert os.path.exists(text_encoder_path), f"Could not find {text_encoder_path}"
            _ = safetensors.torch.load_file(text_encoder_path)

            pipeline = StableDiffusionPipeline.from_pretrained(tmpdirname)
            assert pipeline.unet is not None
            assert pipeline.vae is not None
            assert pipeline.text_encoder is not None
            assert pipeline.scheduler is not None
            assert pipeline.feature_extractor is not None

    def test_no_pytorch_download_when_doing_safetensors(self):
        # by default we don't download
        with tempfile.TemporaryDirectory() as tmpdirname:
            _ = StableDiffusionPipeline.from_pretrained(
                "hf-internal-testing/diffusers-stable-diffusion-tiny-all", cache_dir=tmpdirname
            )

            path = os.path.join(
                tmpdirname,
                "models--hf-internal-testing--diffusers-stable-diffusion-tiny-all",
                "snapshots",
                "07838d72e12f9bcec1375b0482b80c1d399be843",
                "unet",
            )
            # safetensors exists
            assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.safetensors"))
            # pytorch does not
            assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))

    def test_no_safetensors_download_when_doing_pytorch(self):
        # mock safetensors being unavailable to diffusers
        import diffusers

        diffusers.utils.import_utils._safetensors_available = False

        with tempfile.TemporaryDirectory() as tmpdirname:
            _ = StableDiffusionPipeline.from_pretrained(
                "hf-internal-testing/diffusers-stable-diffusion-tiny-all", cache_dir=tmpdirname
            )

            path = os.path.join(
                tmpdirname,
                "models--hf-internal-testing--diffusers-stable-diffusion-tiny-all",
                "snapshots",
                "07838d72e12f9bcec1375b0482b80c1d399be843",
                "unet",
            )
            # safetensors does not exist
            assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.safetensors"))
            # pytorch does
            assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))

        diffusers.utils.import_utils._safetensors_available = True

    def test_optional_components(self):
        unet = self.dummy_cond_unet()
        pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        orig_sd = StableDiffusionPipeline(
            unet=unet,
            scheduler=pndm,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=unet,
            feature_extractor=self.dummy_extractor,
        )
        sd = orig_sd

        assert sd.config.requires_safety_checker is True

        with tempfile.TemporaryDirectory() as tmpdirname:
            sd.save_pretrained(tmpdirname)

            # Test that passing None works
            sd = StableDiffusionPipeline.from_pretrained(
                tmpdirname, feature_extractor=None, safety_checker=None, requires_safety_checker=False
            )

            assert sd.config.requires_safety_checker is False
            assert sd.config.safety_checker == (None, None)
            assert sd.config.feature_extractor == (None, None)

        with tempfile.TemporaryDirectory() as tmpdirname:
            sd.save_pretrained(tmpdirname)

            # Test that loading previous None works
            sd = StableDiffusionPipeline.from_pretrained(tmpdirname)

            assert sd.config.requires_safety_checker is False
            assert sd.config.safety_checker == (None, None)
            assert sd.config.feature_extractor == (None, None)

            orig_sd.save_pretrained(tmpdirname)

            # Test that loading without any directory works
            shutil.rmtree(os.path.join(tmpdirname, "safety_checker"))
            with open(os.path.join(tmpdirname, sd.config_name)) as f:
                config = json.load(f)
                config["safety_checker"] = [None, None]
            with open(os.path.join(tmpdirname, sd.config_name), "w") as f:
                json.dump(config, f)

            sd = StableDiffusionPipeline.from_pretrained(tmpdirname, requires_safety_checker=False)
            sd.save_pretrained(tmpdirname)
            sd = StableDiffusionPipeline.from_pretrained(tmpdirname)

            assert sd.config.requires_safety_checker is False
            assert sd.config.safety_checker == (None, None)
            assert sd.config.feature_extractor == (None, None)

            # Test that loading from deleted model index works
            with open(os.path.join(tmpdirname, sd.config_name)) as f:
                config = json.load(f)
                del config["safety_checker"]
                del config["feature_extractor"]
            with open(os.path.join(tmpdirname, sd.config_name), "w") as f:
                json.dump(config, f)

            sd = StableDiffusionPipeline.from_pretrained(tmpdirname)

            assert sd.config.requires_safety_checker is False
            assert sd.config.safety_checker == (None, None)
            assert sd.config.feature_extractor == (None, None)

        with tempfile.TemporaryDirectory() as tmpdirname:
            sd.save_pretrained(tmpdirname)

            # Test that partially loading works
            sd = StableDiffusionPipeline.from_pretrained(tmpdirname, feature_extractor=self.dummy_extractor)

            assert sd.config.requires_safety_checker is False
            assert sd.config.safety_checker == (None, None)
            assert sd.config.feature_extractor != (None, None)

            # Test that partially loading works
            sd = StableDiffusionPipeline.from_pretrained(
                tmpdirname,
                feature_extractor=self.dummy_extractor,
                safety_checker=unet,
                requires_safety_checker=[True, True],
            )

            assert sd.config.requires_safety_checker == [True, True]
            assert sd.config.safety_checker != (None, None)
            assert sd.config.feature_extractor != (None, None)

        with tempfile.TemporaryDirectory() as tmpdirname:
            sd.save_pretrained(tmpdirname)
            sd = StableDiffusionPipeline.from_pretrained(tmpdirname, feature_extractor=self.dummy_extractor)

            assert sd.config.requires_safety_checker == [True, True]
            assert sd.config.safety_checker != (None, None)
            assert sd.config.feature_extractor != (None, None)


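# Sketch of the optional-component behaviour verified in `test_optional_components`
# above: `safety_checker` and `feature_extractor` may be passed as None (together
# with `requires_safety_checker=False`), and that choice survives save/load
# round-trips. Illustration only; `path` is a placeholder for a saved pipeline.
def _load_without_safety_checker(path):
    return StableDiffusionPipeline.from_pretrained(
        path, safety_checker=None, feature_extractor=None, requires_safety_checker=False
    )

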
@slow
@require_torch_gpu
class PipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_smart_download(self):
        model_id = "hf-internal-testing/unet-pipeline-dummy"
        with tempfile.TemporaryDirectory() as tmpdirname:
            _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
            local_repo_name = "--".join(["models"] + model_id.split("/"))
            snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
            snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])

            # inspect all downloaded files to make sure that everything is included
            assert os.path.isfile(os.path.join(snapshot_dir, DiffusionPipeline.config_name))
            assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME))
            assert os.path.isfile(os.path.join(snapshot_dir, SCHEDULER_CONFIG_NAME))
            assert os.path.isfile(os.path.join(snapshot_dir, WEIGHTS_NAME))
            assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME))
            assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME))
            # let's make sure the super large numpy file:
            # https://huggingface.co/hf-internal-testing/unet-pipeline-dummy/blob/main/big_array.npy
            # is not downloaded, but all the expected ones are
            assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))

    def test_warning_unused_kwargs(self):
        model_id = "hf-internal-testing/unet-pipeline-dummy"
        logger = logging.get_logger("diffusers.pipelines")
        with tempfile.TemporaryDirectory() as tmpdirname:
            with CaptureLogger(logger) as cap_logger:
                DiffusionPipeline.from_pretrained(
                    model_id,
                    not_used=True,
                    cache_dir=tmpdirname,
                    force_download=True,
                )

        assert (
            cap_logger.out.strip().split("\n")[-1]
            == "Keyword arguments {'not_used': True} are not expected by DDPMPipeline and will be ignored."
        )

    def test_from_save_pretrained(self):
        # 1. Load models
        model = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        scheduler = DDPMScheduler(num_train_timesteps=10)

        ddpm = DDPMPipeline(model, scheduler)
        ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        with tempfile.TemporaryDirectory() as tmpdirname:
            ddpm.save_pretrained(tmpdirname)
            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
            new_ddpm.to(torch_device)

        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images

        generator = torch.Generator(device=torch_device).manual_seed(0)
        new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    def test_from_pretrained_hub(self):
        model_path = "google/ddpm-cifar10-32"

        scheduler = DDPMScheduler(num_train_timesteps=10)

        ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
        ddpm = ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
        ddpm_from_hub = ddpm_from_hub.to(torch_device)
        ddpm_from_hub.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images

        generator = torch.Generator(device=torch_device).manual_seed(0)
        new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    def test_from_pretrained_hub_pass_model(self):
        model_path = "google/ddpm-cifar10-32"

        scheduler = DDPMScheduler(num_train_timesteps=10)

        # pass unet into DiffusionPipeline
        unet = UNet2DModel.from_pretrained(model_path)
        ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
        ddpm_from_hub_custom_model = ddpm_from_hub_custom_model.to(torch_device)
        ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)

        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
        ddpm_from_hub = ddpm_from_hub.to(torch_device)
        ddpm_from_hub.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = ddpm_from_hub_custom_model(generator=generator, num_inference_steps=5, output_type="numpy").images

        generator = torch.Generator(device=torch_device).manual_seed(0)
        new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    def test_output_format(self):
        model_path = "google/ddpm-cifar10-32"

        scheduler = DDIMScheduler.from_pretrained(model_path)
        pipe = DDIMPipeline.from_pretrained(model_path, scheduler=scheduler)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        images = pipe(output_type="numpy").images
        assert images.shape == (1, 32, 32, 3)
        assert isinstance(images, np.ndarray)

        images = pipe(output_type="pil", num_inference_steps=4).images
        assert isinstance(images, list)
        assert len(images) == 1
        assert isinstance(images[0], PIL.Image.Image)

        # use PIL by default
        images = pipe(num_inference_steps=4).images
        assert isinstance(images, list)
        assert isinstance(images[0], PIL.Image.Image)

    def test_from_flax_from_pt(self):
        pipe_pt = StableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
        )
        pipe_pt.to(torch_device)

        if not is_flax_available():
            raise ImportError("Make sure flax is installed.")

        from diffusers import FlaxStableDiffusionPipeline

        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe_pt.save_pretrained(tmpdirname)

            pipe_flax, params = FlaxStableDiffusionPipeline.from_pretrained(
                tmpdirname, safety_checker=None, from_pt=True
            )

        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe_flax.save_pretrained(tmpdirname, params=params)
            pipe_pt_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None, from_flax=True)
            pipe_pt_2.to(torch_device)

        prompt = "Hello"

        generator = torch.manual_seed(0)
        image_0 = pipe_pt(
            [prompt],
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        ).images[0]

        generator = torch.manual_seed(0)
        image_1 = pipe_pt_2(
            [prompt],
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        ).images[0]

        assert np.abs(image_0 - image_1).sum() < 1e-5, "Models don't give the same forward pass"

    @require_compel
    def test_weighted_prompts_compel(self):
        from compel import Compel

        pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_attention_slicing()

        compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)

        prompt = "a red cat playing with a ball{}"

        prompts = [prompt.format(s) for s in ["", "++", "--"]]

        prompt_embeds = compel(prompts)

        generator = [torch.Generator(device="cpu").manual_seed(33) for _ in range(prompt_embeds.shape[0])]

        images = pipe(
            prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20, output_type="numpy"
        ).images

        for i, image in enumerate(images):
            expected_image = load_numpy(
                "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
                f"/compel/forest_{i}.npy"
            )

            assert np.abs(image - expected_image).max() < 1e-2


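# Sketch of the determinism check the slow tests above rely on: same seed, two
# pipelines, outputs compared element-wise. The helper name and the `pipe_a` /
# `pipe_b` parameters are our placeholders, not part of the test suite.
def _outputs_match(pipe_a, pipe_b, steps=5, tol=1e-5):
    generator = torch.Generator(device=torch_device).manual_seed(0)
    image_a = pipe_a(generator=generator, num_inference_steps=steps, output_type="numpy").images

    generator = torch.Generator(device=torch_device).manual_seed(0)
    image_b = pipe_b(generator=generator, num_inference_steps=steps, output_type="numpy").images

    return np.abs(image_a - image_b).sum() < tol

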
@nightly
@require_torch_gpu
class PipelineNightlyTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_ddpm_ddim_equality_batched(self):
        seed = 0
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        ddpm_scheduler = DDPMScheduler()
        ddim_scheduler = DDIMScheduler()

        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
        ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
        ddim.to(torch_device)
        ddim.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=torch_device).manual_seed(seed)
        ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy").images

        generator = torch.Generator(device=torch_device).manual_seed(seed)
        ddim_images = ddim(
            batch_size=2,
            generator=generator,
            num_inference_steps=1000,
            eta=1.0,
            output_type="numpy",
            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
        ).images

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_images - ddim_images).max() < 1e-1

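
if __name__ == "__main__":
    # Our addition, not in the original file: lets the module be run directly, e.g.
    # `python test_pipelines.py DownloadTests.test_download_only_pytorch`.
    unittest.main()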