GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/test_modeling_common.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import tempfile
import unittest
import unittest.mock as mock
from typing import Dict, List, Tuple

import numpy as np
import requests_mock
import torch
from requests.exceptions import HTTPError

from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor
from diffusers.training_utils import EMAModel
from diffusers.utils import torch_device


class ModelUtilsTest(unittest.TestCase):
    def tearDown(self):
        super().tearDown()

        import diffusers

        diffusers.utils.import_utils._safetensors_available = True

    def test_accelerate_loading_error_message(self):
        with self.assertRaises(ValueError) as error_context:
            UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")

        # make sure the error message states which keys are missing
        assert "conv_out.bias" in str(error_context.exception)

    def test_cached_files_are_used_when_no_internet(self):
        # A mock response for an HTTP HEAD request, to emulate the server being down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        orig_model = UNet2DConditionModel.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet"
        )

        # Under the mock environment we get a 500 error when trying to reach the model.
        with mock.patch("requests.request", return_value=response_mock):
            # This load should now be served entirely from the local cache.
            model = UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", local_files_only=True
            )

        for p1, p2 in zip(orig_model.parameters(), model.parameters()):
            if p1.data.ne(p2.data).sum() > 0:
                assert False, "Parameters not the same!"

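    # Note: outside of tests, the same cache-only behavior can also be forced
    # globally with the `HF_HUB_OFFLINE=1` environment variable instead of
    # passing `local_files_only=True` per call.
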
    def test_one_request_upon_cached(self):
        # TODO: For some reason this test fails on MPS where no HEAD call is made.
        if torch_device == "mps":
            return

        import diffusers

        diffusers.utils.import_utils._safetensors_available = False

        with tempfile.TemporaryDirectory() as tmpdirname:
            with requests_mock.mock(real_http=True) as m:
                UNet2DConditionModel.from_pretrained(
                    "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
                )

            download_requests = [r.method for r in m.request_history]
            assert download_requests.count("HEAD") == 2, "2 HEAD requests: one for the config, one for the model"
            assert download_requests.count("GET") == 2, "2 GET requests: one for the config, one for the model"

            with requests_mock.mock(real_http=True) as m:
                UNet2DConditionModel.from_pretrained(
                    "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
                )

            cache_requests = [r.method for r in m.request_history]
            assert (
                "HEAD" == cache_requests[0] and len(cache_requests) == 1
            ), "We should call only `model_info` to check for the commit hash and `send_telemetry`"

        diffusers.utils.import_utils._safetensors_available = True


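# `ModelTesterMixin` below is meant to be combined with `unittest.TestCase` by a
# concrete per-model test class, which is expected to provide `model_class`,
# `output_shape`, and `prepare_init_args_and_inputs_for_common()` (see the sketch
# at the end of this file).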
class ModelTesterMixin:
    def test_from_save_pretrained(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        if hasattr(model, "set_attn_processor"):
            model.set_attn_processor(AttnProcessor())
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            if hasattr(new_model, "set_attn_processor"):
                new_model.set_attn_processor(AttnProcessor())
            new_model.to(torch_device)

        with torch.no_grad():
            image = model(**inputs_dict)
            if isinstance(image, dict):
                image = image.sample

            new_image = new_model(**inputs_dict)

            if isinstance(new_image, dict):
                new_image = new_image.sample

        # summed absolute difference between the two forward passes
        max_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

    def test_from_save_pretrained_variant(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        if hasattr(model, "set_attn_processor"):
            model.set_attn_processor(AttnProcessor())
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, variant="fp16")
            new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
            if hasattr(new_model, "set_attn_processor"):
                new_model.set_attn_processor(AttnProcessor())

            # the non-variant checkpoint cannot be loaded
            with self.assertRaises(OSError) as error_context:
                self.model_class.from_pretrained(tmpdirname)

            # make sure the error message states which file is missing
            assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(error_context.exception)

            new_model.to(torch_device)

        with torch.no_grad():
            image = model(**inputs_dict)
            if isinstance(image, dict):
                image = image.sample

            new_image = new_model(**inputs_dict)

            if isinstance(new_image, dict):
                new_image = new_image.sample

        # summed absolute difference between the two forward passes
        max_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

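    # Note: `save_pretrained(..., variant="fp16")` stores the weights under a
    # variant-suffixed filename (e.g. `diffusion_pytorch_model.fp16.bin`), which
    # is why loading without `variant=` in the test above fails to find the
    # default weights file.
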
    def test_from_save_pretrained_dtype(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            if torch_device == "mps" and dtype == torch.bfloat16:
                continue
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.to(dtype)
                model.save_pretrained(tmpdirname)
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
                assert new_model.dtype == dtype
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
                assert new_model.dtype == dtype

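    # For reference, the same `torch_dtype` mechanism applies when loading from
    # the Hub, e.g. (illustrative sketch, reusing the tiny test repo from above):
    #
    #     model = UNet2DConditionModel.from_pretrained(
    #         "hf-internal-testing/tiny-stable-diffusion-torch",
    #         subfolder="unet",
    #         torch_dtype=torch.float16,
    #     )
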
    def test_determinism(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            first = model(**inputs_dict)
            if isinstance(first, dict):
                first = first.sample

            second = model(**inputs_dict)
            if isinstance(second, dict):
                second = second.sample

        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()
        # drop NaN entries, which would otherwise never compare equal
        out_1 = out_1[~np.isnan(out_1)]
        out_2 = out_2[~np.isnan(out_2)]
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, 1e-5)

    def test_output(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_forward_with_norm_groups(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_forward_signature(self):
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        signature = inspect.signature(model.forward)
        # signature.parameters is an OrderedDict, so the order of arg_names is deterministic
        arg_names = [*signature.parameters.keys()]

        expected_arg_names = ["sample", "timestep"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_model_from_pretrained(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # test that the model can be loaded from a saved config
        # and that all parameters have the expected shapes
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # check that all parameter shapes are the same
        for param_name in model.state_dict().keys():
            param_1 = model.state_dict()[param_name]
            param_2 = new_model.state_dict()[param_name]
            self.assertEqual(param_1.shape, param_2.shape)

        with torch.no_grad():
            output_1 = model(**inputs_dict)

            if isinstance(output_1, dict):
                output_1 = output_1.sample

            output_2 = new_model(**inputs_dict)

            if isinstance(output_2, dict):
                output_2 = output_2.sample

        self.assertEqual(output_1.shape, output_2.shape)

    @unittest.skipIf(torch_device == "mps", "Training is not supported on MPS")
    def test_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.sample

        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    @unittest.skipIf(torch_device == "mps", "Training is not supported on MPS")
    def test_ema_training(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        ema_model = EMAModel(model.parameters())

        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.sample

        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()
        ema_model.step(model.parameters())

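    # In a real training loop, the EMA update would typically follow each
    # optimizer step, e.g. (sketch, assuming an `optimizer` exists):
    #
    #     optimizer.step()
    #     optimizer.zero_grad()
    #     ema_model.step(model.parameters())
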
    def test_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
            # Track progress in https://github.com/pytorch/pytorch/issues/77764
            device = t.device
            if device.type == "mps":
                t = t.to("cpu")
            t[t != t] = 0
            return t.to(device)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    torch.allclose(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                    ),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}. Dict has"
                        f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object).any()}."
                    ),
                )

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            outputs_dict = model(**inputs_dict)
            outputs_tuple = model(**inputs_dict, return_dict=False)

        recursive_check(outputs_tuple, outputs_dict)

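    # By default the models return a structured, dict-like output object; passing
    # `return_dict=False` yields a plain tuple with the same values in the same
    # order, which is what the recursive check above relies on.
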
    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
    def test_enable_disable_gradient_checkpointing(self):
        if not self.model_class._supports_gradient_checkpointing:
            return  # Skip test if model does not support gradient checkpointing

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        # at init, the model should have gradient checkpointing disabled
        model = self.model_class(**init_dict)
        self.assertFalse(model.is_gradient_checkpointing)

        # check that enable works
        model.enable_gradient_checkpointing()
        self.assertTrue(model.is_gradient_checkpointing)

        # check that disable works
        model.disable_gradient_checkpointing()
        self.assertFalse(model.is_gradient_checkpointing)

    def test_deprecated_kwargs(self):
        has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
        has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0

        if has_kwarg_in_model_class and not has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
                " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
                " [<deprecated_argument>]`"
            )

        if not has_kwarg_in_model_class and has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
                f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
                " from `_deprecated_kwargs = [<deprecated_argument>]`"
            )
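

# A minimal sketch (hypothetical names) of how a concrete test class would
# combine the mixin above with `unittest.TestCase`:
#
#     class MyModelTests(ModelTesterMixin, unittest.TestCase):
#         model_class = UNet2DConditionModel
#         output_shape = (4, 32, 32)  # per-sample output shape, used by the training tests
#
#         def prepare_init_args_and_inputs_for_common(self):
#             init_dict = {...}  # kwargs for the model constructor
#             inputs_dict = {"sample": ..., "timestep": ...}  # kwargs for forward()
#             return init_dict, inputs_dict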