Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/test_unet_2d_blocks.py
1440 views
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
import unittest
16
17
from diffusers.models.unet_2d_blocks import * # noqa F403
18
from diffusers.utils import torch_device
19
20
from .test_unet_blocks_common import UNetBlockTesterMixin
21
22
23
class DownBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``DownBlock2D`` driven by ``UNetBlockTesterMixin``."""

    block_type = "down"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = DownBlock2D  # noqa: F405

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [-0.0232, -0.9869, 0.8054, -0.0637, -0.1688, -1.4264, 0.4470, -1.3394, 0.0904]
        )
30
31
32
class ResnetDownsampleBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``ResnetDownsampleBlock2D`` using the mixin's default inputs."""

    block_type = "down"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = ResnetDownsampleBlock2D  # noqa: F405

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [0.0710, 0.2410, -0.7320, -1.0757, -1.1343, 0.3540, -0.0133, -0.2576, 0.0948]
        )
39
40
41
class AttnDownBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``AttnDownBlock2D`` using the mixin's default inputs."""

    block_type = "down"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = AttnDownBlock2D  # noqa: F405

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [0.0636, 0.8964, -0.6234, -1.0131, 0.0844, 0.4935, 0.3437, 0.0911, -0.2957]
        )
48
49
50
class CrossAttnDownBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``CrossAttnDownBlock2D``."""

    block_type = "down"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = CrossAttnDownBlock2D  # noqa: F405

    def prepare_init_args_and_inputs_for_common(self):
        """Take the mixin's (init kwargs, inputs) pair and set ``cross_attention_dim`` to 32."""
        # NOTE(review): this class does not override ``dummy_input`` to request
        # ``encoder_hidden_states`` (the Simple* variants do) — confirm the
        # cross-attention path is actually exercised.
        kwargs, inputs = super().prepare_init_args_and_inputs_for_common()
        kwargs["cross_attention_dim"] = 32
        return kwargs, inputs

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [0.2440, -0.6953, -0.2140, -0.3874, 0.1966, 1.2077, 0.0441, -0.7718, 0.2800]
        )
62
63
64
class SimpleCrossAttnDownBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``SimpleCrossAttnDownBlock2D`` with encoder hidden states supplied."""

    block_type = "down"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = SimpleCrossAttnDownBlock2D  # noqa: F405

    def prepare_init_args_and_inputs_for_common(self):
        """Take the mixin's (init kwargs, inputs) pair and set ``cross_attention_dim`` to 32."""
        kwargs, inputs = super().prepare_init_args_and_inputs_for_common()
        kwargs["cross_attention_dim"] = 32
        return kwargs, inputs

    @property
    def dummy_input(self):
        """Mixin dummy inputs, additionally requesting ``encoder_hidden_states``."""
        return super().get_dummy_input(include_encoder_hidden_states=True)

    @unittest.skipIf(torch_device == "mps", "MPS result is not consistent")
    def test_output(self):
        """Hand a hard-coded reference slice to the mixin's ``test_output``; skipped on MPS."""
        super().test_output(
            [0.7921, -0.0992, -0.1962, -0.7695, -0.4242, 0.7804, 0.4737, 0.2765, 0.3338]
        )
81
82
83
class SkipDownBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``SkipDownBlock2D`` with a skip sample in the inputs."""

    block_type = "down"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = SkipDownBlock2D  # noqa: F405

    @property
    def dummy_input(self):
        """Mixin dummy inputs, additionally requesting a ``skip_sample``."""
        return super().get_dummy_input(include_skip_sample=True)

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [-0.0845, -0.2087, -0.2465, 0.0971, 0.1900, -0.0484, 0.2664, 0.4179, 0.5069]
        )
94
95
96
class AttnSkipDownBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``AttnSkipDownBlock2D`` with a skip sample in the inputs."""

    block_type = "down"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = AttnSkipDownBlock2D  # noqa: F405

    @property
    def dummy_input(self):
        """Mixin dummy inputs, additionally requesting a ``skip_sample``."""
        return super().get_dummy_input(include_skip_sample=True)

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [0.5539, 0.1609, 0.4924, 0.0537, -0.1995, 0.4050, 0.0979, -0.2721, -0.0642]
        )
107
108
109
class DownEncoderBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``DownEncoderBlock2D``, run without a time embedding input."""

    block_type = "down"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = DownEncoderBlock2D  # noqa: F405

    @property
    def dummy_input(self):
        """Mixin dummy inputs with ``include_temb=False``."""
        return super().get_dummy_input(include_temb=False)

    def prepare_init_args_and_inputs_for_common(self):
        """Build the init kwargs locally (32 in / 32 out channels) rather than extending the mixin's."""
        init_kwargs = dict(in_channels=32, out_channels=32)
        return init_kwargs, self.dummy_input

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [1.1102, 0.5302, 0.4872, -0.0023, -0.8042, 0.0483, -0.3489, -0.5632, 0.7626]
        )
128
129
130
class AttnDownEncoderBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``AttnDownEncoderBlock2D``, run without a time embedding input."""

    block_type = "down"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = AttnDownEncoderBlock2D  # noqa: F405

    @property
    def dummy_input(self):
        """Mixin dummy inputs with ``include_temb=False``."""
        return super().get_dummy_input(include_temb=False)

    def prepare_init_args_and_inputs_for_common(self):
        """Build the init kwargs locally (32 in / 32 out channels) rather than extending the mixin's."""
        init_kwargs = dict(in_channels=32, out_channels=32)
        return init_kwargs, self.dummy_input

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [0.8966, -0.1486, 0.8568, 0.8141, -0.9046, -0.1342, -0.0972, -0.7417, 0.1538]
        )
149
150
151
class UNetMidBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``UNetMidBlock2D``."""

    block_type = "mid"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = UNetMidBlock2D  # noqa: F405

    def prepare_init_args_and_inputs_for_common(self):
        """Build the init kwargs locally (no ``out_channels``) and pair them with ``self.dummy_input``."""
        # presumably 128 must match the width of the mixin's dummy ``temb`` — TODO confirm
        init_kwargs = dict(in_channels=32, temb_channels=128)
        return init_kwargs, self.dummy_input

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [-0.1062, 1.7248, 0.3494, 1.4569, -0.0910, -1.2421, -0.9984, 0.6736, 1.0028]
        )
166
167
168
class UNetMidBlock2DCrossAttnTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``UNetMidBlock2DCrossAttn``."""

    block_type = "mid"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = UNetMidBlock2DCrossAttn  # noqa: F405

    def prepare_init_args_and_inputs_for_common(self):
        """Take the mixin's (init kwargs, inputs) pair and set ``cross_attention_dim`` to 32."""
        # NOTE(review): no ``encoder_hidden_states`` are requested for this class
        # (the SimpleCrossAttn mid-block test does request them) — confirm intended.
        kwargs, inputs = super().prepare_init_args_and_inputs_for_common()
        kwargs["cross_attention_dim"] = 32
        return kwargs, inputs

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [0.1879, 2.2653, 0.5987, 1.1568, -0.8454, -1.6109, -0.8919, 0.8306, 1.6758]
        )
180
181
182
class UNetMidBlock2DSimpleCrossAttnTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``UNetMidBlock2DSimpleCrossAttn`` with encoder hidden states supplied."""

    block_type = "mid"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = UNetMidBlock2DSimpleCrossAttn  # noqa: F405

    def prepare_init_args_and_inputs_for_common(self):
        """Take the mixin's (init kwargs, inputs) pair and set ``cross_attention_dim`` to 32."""
        kwargs, inputs = super().prepare_init_args_and_inputs_for_common()
        kwargs["cross_attention_dim"] = 32
        return kwargs, inputs

    @property
    def dummy_input(self):
        """Mixin dummy inputs, additionally requesting ``encoder_hidden_states``."""
        return super().get_dummy_input(include_encoder_hidden_states=True)

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [0.7143, 1.9974, 0.5448, 1.3977, 0.1282, -1.1237, -1.4238, 0.5530, 0.8880]
        )
198
199
200
class UpBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``UpBlock2D`` with residual hidden states supplied."""

    block_type = "up"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = UpBlock2D  # noqa: F405

    @property
    def dummy_input(self):
        """Mixin dummy inputs, additionally requesting ``res_hidden_states_tuple``."""
        return super().get_dummy_input(include_res_hidden_states_tuple=True)

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [-0.2041, -0.4165, -0.3022, 0.0041, -0.6628, -0.7053, 0.1928, -0.0325, 0.0523]
        )
211
212
213
class ResnetUpsampleBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``ResnetUpsampleBlock2D`` with residual hidden states supplied."""

    block_type = "up"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = ResnetUpsampleBlock2D  # noqa: F405

    @property
    def dummy_input(self):
        """Mixin dummy inputs, additionally requesting ``res_hidden_states_tuple``."""
        return super().get_dummy_input(include_res_hidden_states_tuple=True)

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [0.2287, 0.3549, -0.1346, 0.4797, -0.1715, -0.9649, 0.7305, -0.5864, -0.6244]
        )
224
225
226
class CrossAttnUpBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``CrossAttnUpBlock2D`` with residual hidden states supplied."""

    block_type = "up"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = CrossAttnUpBlock2D  # noqa: F405

    def prepare_init_args_and_inputs_for_common(self):
        """Take the mixin's (init kwargs, inputs) pair and set ``cross_attention_dim`` to 32."""
        kwargs, inputs = super().prepare_init_args_and_inputs_for_common()
        kwargs["cross_attention_dim"] = 32
        return kwargs, inputs

    @property
    def dummy_input(self):
        """Mixin dummy inputs, additionally requesting ``res_hidden_states_tuple``."""
        # NOTE(review): ``encoder_hidden_states`` is not requested here, unlike the
        # SimpleCrossAttn up-block test — confirm cross-attention is exercised.
        return super().get_dummy_input(include_res_hidden_states_tuple=True)

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [-0.2796, -0.4364, -0.1067, -0.2693, 0.1894, 0.3869, -0.3470, 0.4584, 0.5091]
        )
242
243
244
class SimpleCrossAttnUpBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``SimpleCrossAttnUpBlock2D`` with residual and encoder states supplied."""

    block_type = "up"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = SimpleCrossAttnUpBlock2D  # noqa: F405

    def prepare_init_args_and_inputs_for_common(self):
        """Take the mixin's (init kwargs, inputs) pair and set ``cross_attention_dim`` to 32."""
        kwargs, inputs = super().prepare_init_args_and_inputs_for_common()
        kwargs["cross_attention_dim"] = 32
        return kwargs, inputs

    @property
    def dummy_input(self):
        """Mixin dummy inputs requesting both ``res_hidden_states_tuple`` and ``encoder_hidden_states``."""
        return super().get_dummy_input(include_res_hidden_states_tuple=True, include_encoder_hidden_states=True)

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [0.2645, 0.1480, 0.0909, 0.8044, -0.9758, -0.9083, 0.0994, -1.1453, -0.7402]
        )
260
261
262
class AttnUpBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``AttnUpBlock2D`` with residual hidden states supplied."""

    block_type = "up"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = AttnUpBlock2D  # noqa: F405

    @property
    def dummy_input(self):
        """Mixin dummy inputs, additionally requesting ``res_hidden_states_tuple``."""
        return super().get_dummy_input(include_res_hidden_states_tuple=True)

    @unittest.skipIf(torch_device == "mps", "MPS result is not consistent")
    def test_output(self):
        """Hand a hard-coded reference slice to the mixin's ``test_output``; skipped on MPS."""
        super().test_output(
            [0.0979, 0.1326, 0.0021, 0.0659, 0.2249, 0.0059, 0.1132, 0.5952, 0.1033]
        )
274
275
276
class SkipUpBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``SkipUpBlock2D`` with residual hidden states supplied."""

    block_type = "up"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = SkipUpBlock2D  # noqa: F405

    @property
    def dummy_input(self):
        """Mixin dummy inputs, additionally requesting ``res_hidden_states_tuple``."""
        return super().get_dummy_input(include_res_hidden_states_tuple=True)

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [-0.0893, -0.1234, -0.1506, -0.0332, 0.0123, -0.0211, 0.0566, 0.0143, 0.0362]
        )
287
288
289
class AttnSkipUpBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``AttnSkipUpBlock2D`` with residual hidden states supplied."""

    block_type = "up"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = AttnSkipUpBlock2D  # noqa: F405

    @property
    def dummy_input(self):
        """Mixin dummy inputs, additionally requesting ``res_hidden_states_tuple``."""
        return super().get_dummy_input(include_res_hidden_states_tuple=True)

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [0.0361, 0.0617, 0.2787, -0.0350, 0.0342, 0.3421, -0.0843, 0.0913, 0.3015]
        )
300
301
302
class UpDecoderBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``UpDecoderBlock2D``, run without a time embedding input."""

    block_type = "up"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = UpDecoderBlock2D  # noqa: F405

    @property
    def dummy_input(self):
        """Mixin dummy inputs with ``include_temb=False``."""
        return super().get_dummy_input(include_temb=False)

    def prepare_init_args_and_inputs_for_common(self):
        """Build the init kwargs locally (32 in / 32 out channels) rather than extending the mixin's."""
        init_kwargs = dict(in_channels=32, out_channels=32)
        return init_kwargs, self.dummy_input

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [0.4404, 0.1998, -0.9886, -0.3320, -0.3128, -0.7034, -0.6955, -0.2338, -0.3137]
        )
319
320
321
class AttnUpDecoderBlock2DTests(UNetBlockTesterMixin, unittest.TestCase):
    """Output regression test for ``AttnUpDecoderBlock2D``, run without a time embedding input."""

    block_type = "up"
    # Name comes from the star import of ``diffusers.models.unet_2d_blocks``.
    block_class = AttnUpDecoderBlock2D  # noqa: F405

    @property
    def dummy_input(self):
        """Mixin dummy inputs with ``include_temb=False``."""
        return super().get_dummy_input(include_temb=False)

    def prepare_init_args_and_inputs_for_common(self):
        """Build the init kwargs locally (32 in / 32 out channels) rather than extending the mixin's."""
        init_kwargs = dict(in_channels=32, out_channels=32)
        return init_kwargs, self.dummy_input

    def test_output(self):
        """Hand a hard-coded 9-value reference slice to the mixin's ``test_output``."""
        super().test_output(
            [0.6738, 0.4491, 0.1055, 1.0710, 0.7316, 0.3339, 0.3352, 0.1023, 0.3568]
        )
338
339