# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch

from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device

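# TF32 matmuls round differently from full FP32, so they are kept off to make
# GPU runs reproducible against the hard-coded expected slices in the tests below.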
torch.backends.cuda.matmul.allow_tf32 = False


class DDPMPipelineFastTests(unittest.TestCase):
    @property
    def dummy_uncond_unet(self):
        torch.manual_seed(0)
        model = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        return model

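    # The tiny 32x32 UNet above keeps these fast tests CPU-friendly, and seeding
    # torch.manual_seed(0) before construction fixes its random weights across runs.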
    def test_fast_inference(self):
        device = "cpu"
        unet = self.dummy_uncond_unet
        scheduler = DDPMScheduler()

        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
        ddpm.to(device)
        ddpm.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=device).manual_seed(0)
        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

        generator = torch.Generator(device=device).manual_seed(0)
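        # With return_dict=False the pipeline returns a plain tuple whose first
        # element is the images array, so both calls should yield identical output.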
        image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array(
            [9.956e-01, 5.785e-01, 4.675e-01, 9.930e-01, 0.0, 1.000, 1.199e-03, 2.648e-04, 5.101e-04]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

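    # DDPMScheduler defaults to prediction_type="epsilon" (the UNet predicts the
    # added noise); "sample" below makes it predict the denoised image directly.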
    def test_inference_predict_sample(self):
        unet = self.dummy_uncond_unet
        scheduler = DDPMScheduler(prediction_type="sample")

        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
        ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

        generator = torch.manual_seed(0)
        image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")[0]

        image_slice = image[0, -3:, -3:, -1]
        image_eps_slice = image_eps[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        tolerance = 1e-2 if torch_device != "mps" else 3e-2
        assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance


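# The integration tests below download the real CIFAR-10 checkpoint, so they
# only run when the RUN_SLOW environment variable is set and a CUDA GPU is available.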
@slow
@require_torch_gpu
class DDPMPipelineIntegrationTests(unittest.TestCase):
    def test_inference_cifar10(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDPMScheduler.from_pretrained(model_id)

        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
        ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = ddpm(generator=generator, output_type="numpy").images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4200, 0.3588, 0.1939, 0.3847, 0.3382, 0.2647, 0.4155, 0.3582, 0.3385])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
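

# A minimal usage sketch, not part of the test suite: the integration test above
# assembles the pipeline from its parts, but the same checkpoint can be loaded in
# one call with DDPMPipeline.from_pretrained. The helper name, step count, and
# output path are illustrative choices, not values taken from the tests.
def _demo_ddpm_cifar10():  # hypothetical helper; the leading underscore keeps pytest from collecting it
    pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
    pipe.to(torch_device)
    generator = torch.manual_seed(0)
    # Default output_type is "pil", so .images holds PIL.Image objects.
    image = pipe(generator=generator, num_inference_steps=1000).images[0]
    image.save("ddpm_cifar10_sample.png")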