GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/models/test_models_vae.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch
from parameterized import parameterized

from diffusers import AutoencoderKL
from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device

from ..test_modeling_common import ModelTesterMixin


# Disable TF32 so float32 matmuls run at full precision and the numeric comparisons below stay reproducible.
torch.backends.cuda.matmul.allow_tf32 = False


class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
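    r"""Unit tests for `AutoencoderKL` with a small dummy config, built on the common `ModelTesterMixin` checks."""
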
    model_class = AutoencoderKL

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": [32, 64],
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "latent_channels": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    # The corresponding common checks from ModelTesterMixin are disabled for this model via no-op overrides.
    def test_forward_signature(self):
        pass

    def test_training(self):
        pass

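    # Check that enabling gradient checkpointing leaves the loss and parameter gradients numerically unchanged.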
    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
    def test_gradient_checkpointing(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # run the backward pass on the model. For simplicity, instead of a real training
        # objective we backprop on a dummy loss against random targets.
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # re-instantiate the model, now with gradient checkpointing enabled
        model_2 = self.model_class(**init_dict)
        # clone model
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict).sample
        # run the backward pass again with the same dummy loss and targets
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the outputs and the parameter gradients
        self.assertTrue((loss - loss_2).abs() < 1e-5)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())
        for name, param in named_params.items():
            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

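    # Loading a small pretrained checkpoint from the Hub should report no missing keys and run a forward pass.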
    def test_from_pretrained_hub(self):
        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

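    # Compare a seeded forward pass of the pretrained dummy VAE against per-device reference output slices.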
    def test_output_pretrained(self):
        model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
        model = model.to(torch_device)
        model.eval()

        if torch_device == "mps":
            generator = torch.manual_seed(0)
        else:
            generator = torch.Generator(device=torch_device).manual_seed(0)

        image = torch.randn(
            1,
            model.config.in_channels,
            model.config.sample_size,
            model.config.sample_size,
            generator=torch.manual_seed(0),
        )
        image = image.to(torch_device)
        with torch.no_grad():
            output = model(image, sample_posterior=True, generator=generator).sample

        output_slice = output[0, -1, -3:, -3:].flatten().cpu()

        # Since the posterior is sampled with a generator seeded on the target device,
        # the expected output slices differ between MPS, CPU, and GPU.
        if torch_device == "mps":
            expected_output_slice = torch.tensor(
                [
                    -4.0078e-01,
                    -3.8323e-04,
                    -1.2681e-01,
                    -1.1462e-01,
                    2.0095e-01,
                    1.0893e-01,
                    -8.8247e-02,
                    -3.0361e-01,
                    -9.8644e-03,
                ]
            )
        elif torch_device == "cpu":
            expected_output_slice = torch.tensor(
                [-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026]
            )
        else:
            expected_output_slice = torch.tensor(
                [-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485]
            )

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))


@slow
class AutoencoderKLIntegrationTests(unittest.TestCase):
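    r"""Slow integration tests: run the Stable Diffusion v1-4 VAE on fixed Gaussian-noise inputs loaded via
    `load_hf_numpy` and compare the outputs against hard-coded reference slices."""
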
    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

    def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
        revision = "fp16" if fp16 else None
        torch_dtype = torch.float16 if fp16 else torch.float32

        model = AutoencoderKL.from_pretrained(
            model_id,
            subfolder="vae",
            torch_dtype=torch_dtype,
            revision=revision,
        )
        model.to(torch_device).eval()

        return model

    def get_generator(self, seed=0):
        if torch_device == "mps":
            return torch.manual_seed(seed)
        return torch.Generator(device=torch_device).manual_seed(seed)

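    # Each parameterized case: seed, reference output slice for non-MPS devices, reference output slice for MPS.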
    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824], [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824]],
            [47, [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089], [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131]],
            # fmt: on
        ]
    )
    def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            sample = model(image, generator=generator, sample_posterior=True).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)

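    # Same check with fp16 weights (the "fp16" revision on the Hub), compared with a looser tolerance.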
    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.0513, 0.0289, 1.3799, 0.2166, -0.2573, -0.0871, 0.5103, -0.0999]],
            [47, [-0.4128, -0.1320, -0.3704, 0.1965, -0.4116, -0.2332, -0.3340, 0.2247]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_stable_diffusion_fp16(self, seed, expected_slice):
        model = self.get_sd_vae_model(fp16=True)
        image = self.get_sd_image(seed, fp16=True)
        generator = self.get_generator(seed)

        with torch.no_grad():
            sample = model(image, generator=generator, sample_posterior=True).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-2)

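    # Same comparison but without sampling the posterior (no generator, sample_posterior left at its default),
    # i.e. the deterministic forward path of the VAE.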
    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.1609, 0.9866, -0.0487, -0.0777, -0.2716, 0.8368, -0.2055, -0.0814], [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824]],
            [47, [-0.2377, 0.1147, 0.1333, -0.4841, -0.2506, -0.0805, -0.0491, -0.4085], [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131]],
            # fmt: on
        ]
    )
    def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)

        with torch.no_grad():
            sample = model(image).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)

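    # Decode a batch of three fixed 4x64x64 latents and compare the 3x512x512 reconstructions against reference slices.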
    @parameterized.expand(
        [
            # fmt: off
            [13, [-0.2051, -0.1803, -0.2311, -0.2114, -0.3292, -0.3574, -0.2953, -0.3323]],
            [37, [-0.2632, -0.2625, -0.2199, -0.2741, -0.4539, -0.4990, -0.3720, -0.4925]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_stable_diffusion_decode(self, seed, expected_slice):
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

        with torch.no_grad():
            sample = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)

    @parameterized.expand(
        [
            # fmt: off
            [27, [-0.0369, 0.0207, -0.0776, -0.0682, -0.1747, -0.1930, -0.1465, -0.2039]],
            [16, [-0.1628, -0.2134, -0.2747, -0.2642, -0.3774, -0.4404, -0.3687, -0.4277]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
        model = self.get_sd_vae_model(fp16=True)
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)

        with torch.no_grad():
            sample = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

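    # Encode images and sample the posterior; the latents should have 4 channels at 1/8 of the spatial resolution.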
    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
            [47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
            # fmt: on
        ]
    )
    def test_stable_diffusion_encode_sample(self, seed, expected_slice):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            dist = model.encode(image).latent_dist
            sample = dist.sample(generator=generator)

        assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]

        output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        tolerance = 1e-3 if torch_device != "mps" else 1e-2
        assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)