Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/models/test_models_unet_2d.py
1448 views
1
# coding=utf-8
2
# Copyright 2023 HuggingFace Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import gc
17
import math
18
import unittest
19
20
import torch
21
22
from diffusers import UNet2DModel
23
from diffusers.utils import floats_tensor, logging, slow, torch_all_close, torch_device
24
25
from ..test_modeling_common import ModelTesterMixin
26
27
28
# Turn off TF32 for CUDA matmuls so they run at full float32 precision; the
# hard-coded reference slices compared in this file appear to rely on that.
torch.backends.cuda.matmul.allow_tf32 = False

# Module-level logger obtained from the diffusers logging utility.
logger = logging.get_logger(__name__)
30
31
32
class Unet2DModelTests(ModelTesterMixin, unittest.TestCase):
    """Shared ``ModelTesterMixin`` checks for a tiny 3-channel ``UNet2DModel``.

    The configuration pairs one plain down/up block with one attention
    down/up block.
    """

    model_class = UNet2DModel

    @property
    def dummy_input(self):
        """Build a fresh random forward-pass input.

        Returns:
            dict with ``"sample"`` (a ``(4, 3, 32, 32)`` tensor produced by
            ``floats_tensor``) and ``"timestep"`` (a one-element tensor holding
            10), both moved to ``torch_device``.
        """
        batch, channels, height, width = 4, 3, 32, 32
        sample = floats_tensor((batch, channels, height, width)).to(torch_device)
        timestep = torch.tensor([10]).to(torch_device)
        return dict(sample=sample, timestep=timestep)

    @property
    def input_shape(self):
        """Per-example input shape ``(channels, height, width)``."""
        return 3, 32, 32

    @property
    def output_shape(self):
        """Per-example output shape ``(channels, height, width)``."""
        return 3, 32, 32

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)`` consumed by the mixin's common tests."""
        init_dict = dict(
            block_out_channels=(32, 64),
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
            attention_head_dim=None,
            out_channels=3,
            in_channels=3,
            layers_per_block=2,
            sample_size=32,
        )
        return init_dict, self.dummy_input
67
68
69
class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
    """Tests for a 4-channel ``UNet2DModel`` configuration, including loading the
    ``fusing/unet-ldm-dummy-update`` checkpoint from the Hub."""

    model_class = UNet2DModel

    @property
    def dummy_input(self):
        """Random ``(4, 4, 32, 32)`` sample plus a one-element timestep tensor (10)."""
        batch_size = 4
        num_channels = 4
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        """Per-example input shape ``(channels, height, width)``."""
        return (4, 32, 32)

    @property
    def output_shape(self):
        """Per-example output shape ``(channels, height, width)``."""
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)`` consumed by the mixin's common tests."""
        init_dict = {
            "sample_size": 32,
            "in_channels": 4,
            "out_channels": 4,
            "layers_per_block": 2,
            "block_out_channels": (32, 64),
            "attention_head_dim": 32,
            "down_block_types": ("DownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "UpBlock2D"),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        """The Hub checkpoint loads with no missing keys and produces an output."""
        model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)

        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input).sample

        assert image is not None, "Make sure output is not None"

    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
    def test_from_pretrained_accelerate(self):
        """Default (accelerate-backed) loading works on GPU and produces an output."""
        model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
        model.to(torch_device)
        image = model(**self.dummy_input).sample

        assert image is not None, "Make sure output is not None"

    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
    def test_from_pretrained_accelerate_wont_change_results(self):
        """Loading with and without ``low_cpu_mem_usage`` yields matching outputs."""
        # by default model loading will use accelerate as `low_cpu_mem_usage=True`
        model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
        model_accelerate.to(torch_device)
        model_accelerate.eval()

        noise = torch.randn(
            1,
            model_accelerate.config.in_channels,
            model_accelerate.config.sample_size,
            model_accelerate.config.sample_size,
            generator=torch.manual_seed(0),
        )
        noise = noise.to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

        # Inference only: an output computed with autograd enabled carries a
        # graph whose saved tensors can include the model's weights, which
        # would keep them on the device after the `del` below.
        with torch.no_grad():
            arr_accelerate = model_accelerate(noise, time_step)["sample"]

        # two models don't need to stay in the device at the same time
        del model_accelerate
        # Collect garbage first so any cyclically-held tensors are actually
        # freed before asking the CUDA caching allocator to release its cache.
        gc.collect()
        torch.cuda.empty_cache()

        model_normal_load, _ = UNet2DModel.from_pretrained(
            "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
        )
        model_normal_load.to(torch_device)
        model_normal_load.eval()
        with torch.no_grad():
            arr_normal_load = model_normal_load(noise, time_step)["sample"]

        assert torch_all_close(arr_accelerate, arr_normal_load, rtol=1e-3)

    def test_output_pretrained(self):
        """A seeded forward pass matches a stored reference slice of the output."""
        model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
        model.eval()
        model.to(torch_device)

        noise = torch.randn(
            1,
            model.config.in_channels,
            model.config.sample_size,
            model.config.sample_size,
            generator=torch.manual_seed(0),
        )
        noise = noise.to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        # bottom-right 3x3 patch of the last channel of the first batch item
        output_slice = output[0, -1, -3:, -3:].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
181
182
183
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    """Tests for a ``UNet2DModel`` in an NCSN++-style configuration (Fourier
    time embedding, skip down/up blocks, ``sqrt(2)`` mid-block scale)."""

    model_class = UNet2DModel

    @property
    def dummy_input(self):
        """Random ``(4, 3, 32, 32)`` sample plus a per-example int32 timestep of 10.

        The spatial size is a local rather than a parameter: a property getter
        is only ever invoked with ``self``, so a keyword argument here could
        never be overridden by callers.
        """
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        """Per-example input shape ``(channels, height, width)``."""
        return (3, 32, 32)

    @property
    def output_shape(self):
        """Per-example output shape ``(channels, height, width)``."""
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)`` consumed by the mixin's common tests."""
        init_dict = {
            "block_out_channels": [32, 64, 64, 64],
            "in_channels": 3,
            "layers_per_block": 1,
            "out_channels": 3,
            "time_embedding_type": "fourier",
            "norm_eps": 1e-6,
            "mid_block_scale_factor": math.sqrt(2.0),
            "norm_num_groups": None,
            "down_block_types": [
                "SkipDownBlock2D",
                "AttnSkipDownBlock2D",
                "SkipDownBlock2D",
                "SkipDownBlock2D",
            ],
            "up_block_types": [
                "SkipUpBlock2D",
                "SkipUpBlock2D",
                "AttnSkipUpBlock2D",
                "SkipUpBlock2D",
            ],
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @slow
    def test_from_pretrained_hub(self):
        """The Hub checkpoint loads with no missing keys and produces an output."""
        model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        inputs = self.dummy_input
        # Run this checkpoint at 256x256 (matching the "-256" in its name).
        noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
        inputs["sample"] = noise
        # Inference only: skip building an autograd graph for a large forward pass.
        with torch.no_grad():
            image = model(**inputs).sample

        assert image is not None, "Make sure output is not None"

    @slow
    def test_output_pretrained_ve_mid(self):
        """A forward pass on all-ones input matches a stored reference slice."""
        model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
        model.to(torch_device)

        # Seeding is kept defensively; the inputs below are deterministic (torch.ones).
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        batch_size = 4
        num_channels = 3
        sizes = (256, 256)

        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        # first batch item: last 3 channels, last 3 rows, last column
        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

    def test_output_pretrained_ve_large(self):
        """A forward pass of the dummy VE checkpoint matches a stored reference slice."""
        model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
        model.to(torch_device)

        # Seeding is kept defensively; the inputs below are deterministic (torch.ones).
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        # first batch item: last 3 channels, last 3 rows, last column
        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

    def test_forward_with_norm_groups(self):
        """Override the mixin's norm-groups test as a no-op for this model."""
        # not required for this model (its config sets norm_num_groups=None)
        pass
298
299