# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import tempfile
import unittest

import torch
from parameterized import parameterized

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor
from diffusers.utils import (
    floats_tensor,
    load_hf_numpy,
    logging,
    require_torch_gpu,
    slow,
    torch_all_close,
    torch_device,
)
from diffusers.utils.import_utils import is_xformers_available

from ..test_modeling_common import ModelTesterMixin


logger = logging.get_logger(__name__)
torch.backends.cuda.matmul.allow_tf32 = False

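# Helper used by the LoRA tests below: builds a LoRAAttnProcessor for every
# attention processor in the UNet (hidden size matched to the owning block) and
# bumps the LoRA up-projection weights so the LoRA branch visibly changes the
# output, mocking trained weights.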
def create_lora_layers(model):
    lora_attn_procs = {}
    for name in model.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = model.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(model.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = model.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
        lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

        # add 1 to weights to mock trained weights
        with torch.no_grad():
            lora_attn_procs[name].to_q_lora.up.weight += 1
            lora_attn_procs[name].to_k_lora.up.weight += 1
            lora_attn_procs[name].to_v_lora.up.weight += 1
            lora_attn_procs[name].to_out_lora.up.weight += 1

    return lora_attn_procs

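# Fast unit tests for UNet2DConditionModel, run on a tiny two-block config
# (32/64 channels, 32x32 latents) so each forward pass stays cheap.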
class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNet2DConditionModel

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 4
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)
        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)

        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}

    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": (32, 64),
            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
            "cross_attention_dim": 32,
            "attention_head_dim": 8,
            "out_channels": 4,
            "in_channels": 4,
            "layers_per_block": 2,
            "sample_size": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_enable_works(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

        model.enable_xformers_memory_efficient_attention()

        assert (
            model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
            == "XFormersAttnProcessor"
        ), "xformers is not enabled"

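    # Gradient checkpointing should change neither the loss nor the parameter
    # gradients (up to numerical tolerance) compared to the plain forward pass.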
    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
    def test_gradient_checkpointing(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # run the backwards pass on the model. For simplicity, we don't compute a real
        # training loss and instead backprop on the mean difference to a random target.
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # re-instantiate the model, now enabling gradient checkpointing
        model_2 = self.model_class(**init_dict)
        # clone model
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict).sample
        # run the same backwards pass on the checkpointed model
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the output and parameters gradients
        self.assertTrue((loss - loss_2).abs() < 1e-5)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())
        for name, param in named_params.items():
            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

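    # The following config variants only check that the forward pass runs and
    # that the output keeps the input sample's shape.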
    def test_model_with_attention_head_dim_tuple(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_model_with_use_linear_projection(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["use_linear_projection"] = True

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_model_with_cross_attention_dim_tuple(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["cross_attention_dim"] = (32, 32)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_model_with_simple_projection(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        batch_size, _, _, sample_size = inputs_dict["sample"].shape

        init_dict["class_embed_type"] = "simple_projection"
        init_dict["projection_class_embeddings_input_dim"] = sample_size

        inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_model_with_class_embeddings_concat(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        batch_size, _, _, sample_size = inputs_dict["sample"].shape

        init_dict["class_embed_type"] = "simple_projection"
        init_dict["projection_class_embeddings_input_dim"] = sample_size
        init_dict["class_embeddings_concat"] = True

        inputs_dict["class_labels"] = floats_tensor((batch_size, sample_size)).to(torch_device)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_model_attention_slicing(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        model.set_attention_slice("auto")
        with torch.no_grad():
            output = model(**inputs_dict)
        assert output is not None

        model.set_attention_slice("max")
        with torch.no_grad():
            output = model(**inputs_dict)
        assert output is not None

        model.set_attention_slice(2)
        with torch.no_grad():
            output = model(**inputs_dict)
        assert output is not None

    def test_model_sliceable_head_dim(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        model = self.model_class(**init_dict)

        def check_sliceable_dim_attr(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                assert isinstance(module.sliceable_head_dim, int)

            for child in module.children():
                check_sliceable_dim_attr(child)

        # retrieve number of attention layers
        for module in model.children():
            check_sliceable_dim_attr(module)

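    # Installs a minimal custom attention processor and checks that it is called
    # once for each of the UNet's attention layers and receives extra kwargs
    # forwarded through `cross_attention_kwargs`.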
    def test_special_attn_proc(self):
        class AttnEasyProc(torch.nn.Module):
            def __init__(self, num):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.tensor(num))
                self.is_run = False
                self.number = 0
                self.counter = 0

            def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
                batch_size, sequence_length, _ = hidden_states.shape
                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

                query = attn.to_q(hidden_states)

                encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
                key = attn.to_k(encoder_hidden_states)
                value = attn.to_v(encoder_hidden_states)

                query = attn.head_to_batch_dim(query)
                key = attn.head_to_batch_dim(key)
                value = attn.head_to_batch_dim(value)

                attention_probs = attn.get_attention_scores(query, key, attention_mask)
                hidden_states = torch.bmm(attention_probs, value)
                hidden_states = attn.batch_to_head_dim(hidden_states)

                # linear proj
                hidden_states = attn.to_out[0](hidden_states)
                # dropout
                hidden_states = attn.to_out[1](hidden_states)

                hidden_states += self.weight

                self.is_run = True
                self.counter += 1
                self.number = number

                return hidden_states

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        model = self.model_class(**init_dict)
        model.to(torch_device)

        processor = AttnEasyProc(5.0)

        model.set_attn_processor(processor)
        model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample

        assert processor.counter == 12
        assert processor.is_run
        assert processor.number == 123

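    # With the LoRA up-weights bumped to nonzero values, `scale=0.0` must reproduce
    # the base model's output, and identical scales must give identical outputs.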
    def test_lora_processors(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        model = self.model_class(**init_dict)
        model.to(torch_device)

        with torch.no_grad():
            sample1 = model(**inputs_dict).sample

        lora_attn_procs = {}
        for name in model.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = model.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = model.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

            # add 1 to weights to mock trained weights
            with torch.no_grad():
                lora_attn_procs[name].to_q_lora.up.weight += 1
                lora_attn_procs[name].to_k_lora.up.weight += 1
                lora_attn_procs[name].to_v_lora.up.weight += 1
                lora_attn_procs[name].to_out_lora.up.weight += 1

        # make sure we can set a dict of attention processors
        model.set_attn_processor(lora_attn_procs)
        model.to(torch_device)

        # test that attn processors can be set to itself
        model.set_attn_processor(model.attn_processors)

        with torch.no_grad():
            sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
            sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
            sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

        assert (sample1 - sample2).abs().max() < 1e-4
        assert (sample3 - sample4).abs().max() < 1e-4

        # sample 2 and sample 3 should be different
        assert (sample2 - sample3).abs().max() > 1e-4

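    # `save_attn_procs` / `load_attn_procs` should round-trip the LoRA weights
    # (saved as `pytorch_lora_weights.bin` by default) and reproduce the same output.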
    def test_lora_save_load(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.to(torch_device)

        with torch.no_grad():
            old_sample = model(**inputs_dict).sample

        lora_attn_procs = create_lora_layers(model)
        model.set_attn_processor(lora_attn_procs)

        with torch.no_grad():
            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_attn_procs(tmpdirname)
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            torch.manual_seed(0)
            new_model = self.model_class(**init_dict)
            new_model.to(torch_device)
            new_model.load_attn_procs(tmpdirname)

        with torch.no_grad():
            new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

        assert (sample - new_sample).abs().max() < 1e-4

        # LoRA and no LoRA should NOT be the same
        assert (sample - old_sample).abs().max() > 1e-4

    def test_lora_save_load_safetensors(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.to(torch_device)

        with torch.no_grad():
            old_sample = model(**inputs_dict).sample

        lora_attn_procs = {}
        for name in model.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = model.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = model.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

            # add 1 to weights to mock trained weights
            with torch.no_grad():
                lora_attn_procs[name].to_q_lora.up.weight += 1
                lora_attn_procs[name].to_k_lora.up.weight += 1
                lora_attn_procs[name].to_v_lora.up.weight += 1
                lora_attn_procs[name].to_out_lora.up.weight += 1

        model.set_attn_processor(lora_attn_procs)

        with torch.no_grad():
            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_attn_procs(tmpdirname, safe_serialization=True)
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            torch.manual_seed(0)
            new_model = self.model_class(**init_dict)
            new_model.to(torch_device)
            new_model.load_attn_procs(tmpdirname)

        with torch.no_grad():
            new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

        assert (sample - new_sample).abs().max() < 1e-4

        # LoRA and no LoRA should NOT be the same
        assert (sample - old_sample).abs().max() > 1e-4

    def test_lora_save_safetensors_load_torch(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.to(torch_device)

        lora_attn_procs = {}
        for name in model.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = model.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = model.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

        model.set_attn_processor(lora_attn_procs)
        # saved in the default torch format; reloads correctly when the weight filename is passed explicitly
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_attn_procs(tmpdirname)
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            torch.manual_seed(0)
            new_model = self.model_class(**init_dict)
            new_model.to(torch_device)
            new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")

    def test_lora_save_torch_force_load_safetensors_error(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.to(torch_device)

        lora_attn_procs = {}
        for name in model.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = model.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = model.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

        model.set_attn_processor(lora_attn_procs)
        # only the torch-format file is saved, so forcing safetensors on load must raise
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_attn_procs(tmpdirname)
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            torch.manual_seed(0)
            new_model = self.model_class(**init_dict)
            new_model.to(torch_device)
            with self.assertRaises(IOError) as e:
                new_model.load_attn_procs(tmpdirname, use_safetensors=True)
            self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception))

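    # With the LoRA scale set to 0.0, and again after restoring a plain AttnProcessor,
    # the output should match the original non-LoRA sample.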
    def test_lora_on_off(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.to(torch_device)

        with torch.no_grad():
            old_sample = model(**inputs_dict).sample

        lora_attn_procs = create_lora_layers(model)
        model.set_attn_processor(lora_attn_procs)

        with torch.no_grad():
            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample

        model.set_attn_processor(AttnProcessor())

        with torch.no_grad():
            new_sample = model(**inputs_dict).sample

        assert (sample - new_sample).abs().max() < 1e-4
        assert (sample - old_sample).abs().max() < 1e-4

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_lora_xformers_on_off(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.to(torch_device)
        lora_attn_procs = create_lora_layers(model)
        model.set_attn_processor(lora_attn_procs)

        # default
        with torch.no_grad():
            sample = model(**inputs_dict).sample

            model.enable_xformers_memory_efficient_attention()
            on_sample = model(**inputs_dict).sample

            model.disable_xformers_memory_efficient_attention()
            off_sample = model(**inputs_dict).sample

        assert (sample - on_sample).abs().max() < 1e-4
        assert (sample - off_sample).abs().max() < 1e-4

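# Slow integration tests: run real Stable Diffusion UNet checkpoints on GPU and
# compare output slices and peak memory against reference values. Reference
# latents and encoder hidden states are precomputed .npy files fetched with
# load_hf_numpy.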
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

    def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
        revision = "fp16" if fp16 else None
        torch_dtype = torch.float16 if fp16 else torch.float32

        model = UNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision
        )
        model.to(torch_device).eval()

        return model

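    # Each attention-slicing mode should keep peak CUDA memory for a single
    # forward pass under roughly 5 GB.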
    def test_set_attention_slice_auto(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        unet = self.get_unet_model()
        unet.set_attention_slice("auto")

        latents = self.get_latents(33)
        encoder_hidden_states = self.get_encoder_hidden_states(33)
        timestep = 1

        with torch.no_grad():
            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        mem_bytes = torch.cuda.max_memory_allocated()

        assert mem_bytes < 5 * 10**9

    def test_set_attention_slice_max(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        unet = self.get_unet_model()
        unet.set_attention_slice("max")

        latents = self.get_latents(33)
        encoder_hidden_states = self.get_encoder_hidden_states(33)
        timestep = 1

        with torch.no_grad():
            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        mem_bytes = torch.cuda.max_memory_allocated()

        assert mem_bytes < 5 * 10**9

    def test_set_attention_slice_int(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        unet = self.get_unet_model()
        unet.set_attention_slice(2)

        latents = self.get_latents(33)
        encoder_hidden_states = self.get_encoder_hidden_states(33)
        timestep = 1

        with torch.no_grad():
            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        mem_bytes = torch.cuda.max_memory_allocated()

        assert mem_bytes < 5 * 10**9

    def test_set_attention_slice_list(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        # there are 32 sliceable layers
        slice_list = 16 * [2, 3]
        unet = self.get_unet_model()
        unet.set_attention_slice(slice_list)

        latents = self.get_latents(33)
        encoder_hidden_states = self.get_encoder_hidden_states(33)
        timestep = 1

        with torch.no_grad():
            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        mem_bytes = torch.cuda.max_memory_allocated()

        assert mem_bytes < 5 * 10**9

    def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return hidden_states

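    # The parameterized cases below check a fixed slice of the UNet output
    # against precomputed reference values for several (seed, timestep) pairs.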
    @parameterized.expand(
        [
            # fmt: off
            [33, 4, [-0.4424, 0.1510, -0.1937, 0.2118, 0.3746, -0.3957, 0.0160, -0.0435]],
            [47, 0.55, [-0.1508, 0.0379, -0.3075, 0.2540, 0.3633, -0.0821, 0.1719, -0.0207]],
            [21, 0.89, [-0.6479, 0.6364, -0.3464, 0.8697, 0.4443, -0.6289, -0.0091, 0.1778]],
            [9, 1000, [0.8888, -0.5659, 0.5834, -0.7469, 1.1912, -0.3923, 1.1241, -0.4424]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
        latents = self.get_latents(seed)
        encoder_hidden_states = self.get_encoder_hidden_states(seed)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
            [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
            [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
            [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
        latents = self.get_latents(seed, fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [33, 4, [-0.4430, 0.1570, -0.1867, 0.2376, 0.3205, -0.3681, 0.0525, -0.0722]],
            [47, 0.55, [-0.1415, 0.0129, -0.3136, 0.2257, 0.3430, -0.0536, 0.2114, -0.0436]],
            [21, 0.89, [-0.7091, 0.6664, -0.3643, 0.9032, 0.4499, -0.6541, 0.0139, 0.1750]],
            [9, 1000, [0.8878, -0.5659, 0.5844, -0.7442, 1.1883, -0.3927, 1.1192, -0.4423]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
        latents = self.get_latents(seed)
        encoder_hidden_states = self.get_encoder_hidden_states(seed)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [-0.2695, -0.1669, 0.0073, -0.3181, -0.1187, -0.1676, -0.1395, -0.5972]],
            [17, 0.55, [-0.1290, -0.2588, 0.0551, -0.0916, 0.3286, 0.0238, -0.3669, 0.0322]],
            [8, 0.89, [-0.5283, 0.1198, 0.0870, -0.1141, 0.9189, -0.0150, 0.5474, 0.4319]],
            [3, 1000, [-0.5601, 0.2411, -0.5435, 0.1268, 1.1338, -0.2427, -0.0280, -1.0020]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True)
        latents = self.get_latents(seed, fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [33, 4, [-0.7639, 0.0106, -0.1615, -0.3487, -0.0423, -0.7972, 0.0085, -0.4858]],
            [47, 0.55, [-0.6564, 0.0795, -1.9026, -0.6258, 1.8235, 1.2056, 1.2169, 0.9073]],
            [21, 0.89, [0.0327, 0.4399, -0.6358, 0.3417, 0.4120, -0.5621, -0.0397, -1.0430]],
            [9, 1000, [0.1600, 0.7303, -1.0556, -0.3515, -0.7440, -1.2037, -1.8149, -1.8931]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
        latents = self.get_latents(seed, shape=(4, 9, 64, 64))
        encoder_hidden_states = self.get_encoder_hidden_states(seed)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == (4, 4, 64, 64)

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [-0.1047, -1.7227, 0.1067, 0.0164, -0.5698, -0.4172, -0.1388, 1.1387]],
            [17, 0.55, [0.0975, -0.2856, -0.3508, -0.4600, 0.3376, 0.2930, -0.2747, -0.7026]],
            [8, 0.89, [-0.0952, 0.0183, -0.5825, -0.1981, 0.1131, 0.4668, -0.0395, -0.3486]],
            [3, 1000, [0.4790, 0.4949, -1.0732, -0.7158, 0.7959, -0.9478, 0.1105, -0.9741]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True)
        latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == (4, 4, 64, 64)

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
            [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
            [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
            [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
        latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)