GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/test_layers_utils.py

# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

import numpy as np
import torch
from torch import nn

from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModel
from diffusers.utils import torch_device


torch.backends.cuda.matmul.allow_tf32 = False
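# NOTE: with TF32 disabled, float32 matmuls are not run at reduced precision on
# Ampere+ GPUs; the hard-coded expected slices below assume full-precision results.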


class EmbeddingsTests(unittest.TestCase):
    def test_timestep_embeddings(self):
        embedding_dim = 256
        timesteps = torch.arange(16)

        t1 = get_timestep_embedding(timesteps, embedding_dim)

        # first vector should always be composed only of 0's and 1's
        assert (t1[0, : embedding_dim // 2] - 0).abs().sum() < 1e-5
        assert (t1[0, embedding_dim // 2 :] - 1).abs().sum() < 1e-5

        # last element of each vector should be one
        assert (t1[:, -1] - 1).abs().sum() < 1e-5

        # For large embeddings (e.g. 128) the frequency of every vector is higher
        # than the previous one which means that the gradients of later vectors are
        # ALWAYS higher than the previous ones
        grad_mean = np.abs(np.gradient(t1, axis=-1)).mean(axis=1)

        prev_grad = 0.0
        for grad in grad_mean:
            assert grad > prev_grad
            prev_grad = grad
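
    # The assertions above follow from the standard sinusoidal construction. As a
    # rough sketch (illustrative only, not the library implementation; the defaults
    # flip_sin_to_cos=False and downscale_freq_shift=1 are assumed),
    # get_timestep_embedding computes roughly:
    #
    #   half = embedding_dim // 2
    #   freqs = torch.exp(-math.log(max_period) * torch.arange(half) / (half - 1))
    #   args = scale * timesteps[:, None].float() * freqs[None, :]
    #   emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    #
    # so at t=0 the sin half is all zeros and the cos half is all ones, and the last
    # column uses the lowest frequency (1 / max_period), which stays ~1 for small t.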

    def test_timestep_defaults(self):
        embedding_dim = 16
        timesteps = torch.arange(10)

        t1 = get_timestep_embedding(timesteps, embedding_dim)
        t2 = get_timestep_embedding(
            timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000
        )

        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)

    def test_timestep_flip_sin_cos(self):
        embedding_dim = 16
        timesteps = torch.arange(10)

        t1 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True)
        t1 = torch.cat([t1[:, embedding_dim // 2 :], t1[:, : embedding_dim // 2]], dim=-1)

        t2 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False)

        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)

    def test_timestep_downscale_freq_shift(self):
        embedding_dim = 16
        timesteps = torch.arange(10)

        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0)
        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1)

        # get cosine half (vectors that are wrapped into cosine)
        cosine_half = (t1 - t2)[:, embedding_dim // 2 :]

        # cosine needs to be negative
        assert (np.abs((cosine_half <= 0).numpy()) - 1).sum() < 1e-5

    def test_sinusoid_embeddings_hardcoded(self):
        embedding_dim = 64
        timesteps = torch.arange(128)

        # standard unet, score_sde
        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1, flip_sin_to_cos=False)
        # glide, ldm
        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0, flip_sin_to_cos=True)
        # grad-tts
        t3 = get_timestep_embedding(timesteps, embedding_dim, scale=1000)

        assert torch.allclose(
            t1[23:26, 47:50].flatten().cpu(),
            torch.tensor([0.9646, 0.9804, 0.9892, 0.9615, 0.9787, 0.9882, 0.9582, 0.9769, 0.9872]),
            1e-3,
        )
        assert torch.allclose(
            t2[23:26, 47:50].flatten().cpu(),
            torch.tensor([0.3019, 0.2280, 0.1716, 0.3146, 0.2377, 0.1790, 0.3272, 0.2474, 0.1864]),
            1e-3,
        )
        assert torch.allclose(
            t3[23:26, 47:50].flatten().cpu(),
            torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
            1e-3,
        )


class Upsample2DBlockTests(unittest.TestCase):
    def test_upsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
        upsample = Upsample2D(channels=32, use_conv=False)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64, 64)
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
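
    # The duplicated values in expected_slice above (-1.2079 twice, 1.1254 repeated)
    # are consistent with plain nearest-neighbor 2x upsampling in the default
    # (use_conv=False, use_conv_transpose=False) path: each input pixel is copied
    # into a 2x2 block, so the bottom-right 3x3 window repeats rows and columns of
    # the 32x32 input.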

    def test_upsample_with_conv(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
        upsample = Upsample2D(channels=32, use_conv=True)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64, 64)
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([0.7145, 1.3773, 0.3492, 0.8448, 1.0839, -0.3341, 0.5956, 0.1250, -0.4841])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_upsample_with_conv_out_dim(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
        upsample = Upsample2D(channels=32, use_conv=True, out_channels=64)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 64, 64, 64)
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([0.2703, 0.1656, -0.2538, -0.0553, -0.2984, 0.1044, 0.1155, 0.2579, 0.7755])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_upsample_with_transpose(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32, 32)
        upsample = Upsample2D(channels=32, use_conv=False, use_conv_transpose=True)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64, 64)
        output_slice = upsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)


class Downsample2DBlockTests(unittest.TestCase):
    def test_downsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
        downsample = Downsample2D(channels=32, use_conv=False)
        with torch.no_grad():
            downsampled = downsample(sample)

        assert downsampled.shape == (1, 32, 32, 32)
        output_slice = downsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.0513, -0.3889, 0.0640, 0.0836, -0.5460, -0.0341, -0.0169, -0.6967, 0.1179])
        max_diff = (output_slice.flatten() - expected_slice).abs().sum().item()
        assert max_diff <= 1e-3
        # assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1)
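
    # In the default (use_conv=False) path the block halves each spatial dimension
    # (64x64 -> 32x32); this is presumably a 2x2 average-pool style reduction rather
    # than a learned convolution, since no conv weights are involved when
    # use_conv=False.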

    def test_downsample_with_conv(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
        downsample = Downsample2D(channels=32, use_conv=True)
        with torch.no_grad():
            downsampled = downsample(sample)

        assert downsampled.shape == (1, 32, 32, 32)
        output_slice = downsampled[0, -1, -3:, -3:]

        expected_slice = torch.tensor(
            [0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913],
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_downsample_with_conv_pad1(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
        downsample = Downsample2D(channels=32, use_conv=True, padding=1)
        with torch.no_grad():
            downsampled = downsample(sample)

        assert downsampled.shape == (1, 32, 32, 32)
        output_slice = downsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([0.9267, 0.5878, 0.3337, 1.2321, -0.1191, -0.3984, -0.7532, -0.0715, -0.3913])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_downsample_with_conv_out_dim(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64)
        downsample = Downsample2D(channels=32, use_conv=True, out_channels=16)
        with torch.no_grad():
            downsampled = downsample(sample)

        assert downsampled.shape == (1, 16, 32, 32)
        output_slice = downsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)


class ResnetBlock2DTests(unittest.TestCase):
    def test_resnet_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 64, 64)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [-1.9010, -0.2974, -0.8245, -1.3533, 0.8742, -0.9645, -2.0584, 1.3387, -0.4746], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_resnet_with_use_in_shortcut(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, use_in_shortcut=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 64, 64)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [0.2226, -1.0791, -0.1629, 0.3659, -0.2889, -1.2376, 0.0582, 0.9206, 0.0044], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_resnet_up(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, up=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 128, 128)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [1.2130, -0.8753, -0.9027, 1.5783, -0.5362, -0.5001, 1.0726, -0.7732, -0.4182], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_resnet_down(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, down=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 32, 32)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_resnet_with_kernel_fir(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, kernel="fir", down=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 32, 32)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [-0.0934, -0.5729, 0.0909, -0.2710, -0.5044, 0.0243, -0.0665, -0.5267, -0.3136], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_resnet_with_kernel_sde_vp(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        temb = torch.randn(1, 128).to(torch_device)
        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, kernel="sde_vp", down=True).to(torch_device)
        with torch.no_grad():
            output_tensor = resnet_block(sample, temb)

        assert output_tensor.shape == (1, 32, 32, 32)
        output_slice = output_tensor[0, -1, -3:, -3:]
        expected_slice = torch.tensor(
            [-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)


class AttentionBlockTests(unittest.TestCase):
    @unittest.skipIf(
        torch_device == "mps", "Matmul crashes on MPS, see https://github.com/pytorch/pytorch/issues/84039"
    )
    def test_attention_block_default(self):
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        attention_block = AttentionBlock(
            channels=32,
            num_head_channels=1,
            rescale_output_factor=1.0,
            eps=1e-6,
            norm_num_groups=32,
        ).to(torch_device)
        with torch.no_grad():
            attention_scores = attention_block(sample)

        assert attention_scores.shape == (1, 32, 64, 64)
        output_slice = attention_scores[0, -1, -3:, -3:]

        expected_slice = torch.tensor(
            [-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_attention_block_sd(self):
        # This version uses SD params and is compatible with mps
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        sample = torch.randn(1, 512, 64, 64).to(torch_device)
        attention_block = AttentionBlock(
            channels=512,
            rescale_output_factor=1.0,
            eps=1e-6,
            norm_num_groups=32,
        ).to(torch_device)
        with torch.no_grad():
            attention_scores = attention_block(sample)

        assert attention_scores.shape == (1, 512, 64, 64)
        output_slice = attention_scores[0, -1, -3:, -3:]

        expected_slice = torch.tensor(
            [-0.6621, -0.0156, -3.2766, 0.8025, -0.8609, 0.2820, 0.0905, -1.1179, -3.2126], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)


class Transformer2DModelTests(unittest.TestCase):
    def test_spatial_transformer_default(self):
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        spatial_transformer_block = Transformer2DModel(
            in_channels=32,
            num_attention_heads=1,
            attention_head_dim=32,
            dropout=0.0,
            cross_attention_dim=None,
        ).to(torch_device)
        with torch.no_grad():
            attention_scores = spatial_transformer_block(sample).sample

        assert attention_scores.shape == (1, 32, 64, 64)
        output_slice = attention_scores[0, -1, -3:, -3:]

        expected_slice = torch.tensor(
            [-1.9455, -0.0066, -1.3933, -1.5878, 0.5325, -0.6486, -1.8648, 0.7515, -0.9689], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_spatial_transformer_cross_attention_dim(self):
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        sample = torch.randn(1, 64, 64, 64).to(torch_device)
        spatial_transformer_block = Transformer2DModel(
            in_channels=64,
            num_attention_heads=2,
            attention_head_dim=32,
            dropout=0.0,
            cross_attention_dim=64,
        ).to(torch_device)
        with torch.no_grad():
            context = torch.randn(1, 4, 64).to(torch_device)
            attention_scores = spatial_transformer_block(sample, context).sample

        assert attention_scores.shape == (1, 64, 64, 64)
        output_slice = attention_scores[0, -1, -3:, -3:]

        expected_slice = torch.tensor(
            [-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_spatial_transformer_timestep(self):
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        num_embeds_ada_norm = 5

        sample = torch.randn(1, 64, 64, 64).to(torch_device)
        spatial_transformer_block = Transformer2DModel(
            in_channels=64,
            num_attention_heads=2,
            attention_head_dim=32,
            dropout=0.0,
            cross_attention_dim=64,
            num_embeds_ada_norm=num_embeds_ada_norm,
        ).to(torch_device)
        with torch.no_grad():
            timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device)
            timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device)
            attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1).sample
            attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2).sample

        assert attention_scores_1.shape == (1, 64, 64, 64)
        assert attention_scores_2.shape == (1, 64, 64, 64)

        output_slice_1 = attention_scores_1[0, -1, -3:, -3:]
        output_slice_2 = attention_scores_2[0, -1, -3:, -3:]

        expected_slice_1 = torch.tensor(
            [-0.1874, -0.9704, -1.4290, -1.3357, 1.5138, 0.3036, -0.0976, -1.1667, 0.1283], device=torch_device
        )
        expected_slice_2 = torch.tensor(
            [-0.3493, -1.0924, -1.6161, -1.5016, 1.4245, 0.1367, -0.2526, -1.3109, -0.0547], device=torch_device
        )

        assert torch.allclose(output_slice_1.flatten(), expected_slice_1, atol=1e-3)
        assert torch.allclose(output_slice_2.flatten(), expected_slice_2, atol=1e-3)

    def test_spatial_transformer_dropout(self):
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        sample = torch.randn(1, 32, 64, 64).to(torch_device)
        spatial_transformer_block = (
            Transformer2DModel(
                in_channels=32,
                num_attention_heads=2,
                attention_head_dim=16,
                dropout=0.3,
                cross_attention_dim=None,
            )
            .to(torch_device)
            .eval()
        )
        with torch.no_grad():
            attention_scores = spatial_transformer_block(sample).sample

        assert attention_scores.shape == (1, 32, 64, 64)
        output_slice = attention_scores[0, -1, -3:, -3:]

        expected_slice = torch.tensor(
            [-1.9380, -0.0083, -1.3771, -1.5819, 0.5209, -0.6441, -1.8545, 0.7563, -0.9615], device=torch_device
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    @unittest.skipIf(torch_device == "mps", "MPS does not support float64")
    def test_spatial_transformer_discrete(self):
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        num_embed = 5

        sample = torch.randint(0, num_embed, (1, 32)).to(torch_device)
        spatial_transformer_block = (
            Transformer2DModel(
                num_attention_heads=1,
                attention_head_dim=32,
                num_vector_embeds=num_embed,
                sample_size=16,
            )
            .to(torch_device)
            .eval()
        )

        with torch.no_grad():
            attention_scores = spatial_transformer_block(sample).sample

        assert attention_scores.shape == (1, num_embed - 1, 32)

        output_slice = attention_scores[0, -2:, -3:]

        expected_slice = torch.tensor([-1.7648, -1.0241, -2.0985, -1.8035, -1.6404, -1.2098], device=torch_device)
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_spatial_transformer_default_norm_layers(self):
        spatial_transformer_block = Transformer2DModel(num_attention_heads=1, attention_head_dim=32, in_channels=32)

        assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == nn.LayerNorm
        assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm

    def test_spatial_transformer_ada_norm_layers(self):
        spatial_transformer_block = Transformer2DModel(
            num_attention_heads=1,
            attention_head_dim=32,
            in_channels=32,
            num_embeds_ada_norm=5,
        )

        assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == AdaLayerNorm
        assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm
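
    # AdaLayerNorm (asserted for norm1 above) conditions the normalization on the
    # timestep: an embedding of the timestep predicts a per-channel scale and shift,
    # roughly along the lines of
    #
    #   scale, shift = linear(emb(timestep)).chunk(2, dim=-1)
    #   x = layer_norm(x) * (1 + scale) + shift
    #
    # This is only an illustrative sketch of the mechanism, not the module's code;
    # see diffusers.models.attention.AdaLayerNorm for the actual implementation.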

    def test_spatial_transformer_default_ff_layers(self):
        spatial_transformer_block = Transformer2DModel(
            num_attention_heads=1,
            attention_head_dim=32,
            in_channels=32,
        )

        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU
        assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear

        dim = 32
        inner_dim = 128

        # First dimension change
        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim
        # NOTE: inner_dim * 2 because GEGLU
        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim * 2

        # Second dimension change
        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim
        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim
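
    # The inner_dim * 2 above reflects how a GEGLU feed-forward gate works; as a
    # minimal sketch (illustrative, not the library code):
    #
    #   value, gate = proj(x).chunk(2, dim=-1)   # proj: dim -> inner_dim * 2
    #   out = value * F.gelu(gate)               # output has inner_dim features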

    def test_spatial_transformer_geglu_approx_ff_layers(self):
        spatial_transformer_block = Transformer2DModel(
            num_attention_heads=1,
            attention_head_dim=32,
            in_channels=32,
            activation_fn="geglu-approximate",
        )

        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU
        assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear

        dim = 32
        inner_dim = 128

        # First dimension change
        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim
        assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim

        # Second dimension change
        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim
        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim

    def test_spatial_transformer_attention_bias(self):
        spatial_transformer_block = Transformer2DModel(
            num_attention_heads=1, attention_head_dim=32, in_channels=32, attention_bias=True
        )

        assert spatial_transformer_block.transformer_blocks[0].attn1.to_q.bias is not None
        assert spatial_transformer_block.transformer_blocks[0].attn1.to_k.bias is not None
        assert spatial_transformer_block.transformer_blocks[0].attn1.to_v.bias is not None