# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer

from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
from diffusers.utils import load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference

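# Fast tests: the full UnCLIP stack (prior, text projection, decoder and two
# super-resolution UNets) is assembled from tiny, randomly initialized components,
# so the whole text-to-image pipeline runs in seconds on CPU.
# One typical way to run just this module is, for example:
#   pytest tests/pipelines/unclip/test_unclip.py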
class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = UnCLIPPipeline
    params = TEXT_TO_IMAGE_PARAMS - {
        "negative_prompt",
        "height",
        "width",
        "negative_prompt_embeds",
        "guidance_scale",
        "prompt_embeds",
        "cross_attention_kwargs",
    }
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    required_optional_params = [
        "generator",
        "return_dict",
        "prior_num_inference_steps",
        "decoder_num_inference_steps",
        "super_res_num_inference_steps",
    ]
    test_xformers_attention = False

    @property
    def text_embedder_hidden_size(self):
        return 32

    @property
    def time_input_dim(self):
        return 32

    @property
    def block_out_channels_0(self):
        return self.time_input_dim

    @property
    def time_embed_dim(self):
        return self.time_input_dim * 4

    @property
    def cross_attention_dim(self):
        return 100

    @property
    def dummy_tokenizer(self):
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        return tokenizer

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=self.text_embedder_hidden_size,
            projection_dim=self.text_embedder_hidden_size,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModelWithProjection(config)

    @property
    def dummy_prior(self):
        torch.manual_seed(0)

        model_kwargs = {
            "num_attention_heads": 2,
            "attention_head_dim": 12,
            "embedding_dim": self.text_embedder_hidden_size,
            "num_layers": 1,
        }

        model = PriorTransformer(**model_kwargs)
        return model

    @property
    def dummy_text_proj(self):
        torch.manual_seed(0)

        model_kwargs = {
            "clip_embeddings_dim": self.text_embedder_hidden_size,
            "time_embed_dim": self.time_embed_dim,
            "cross_attention_dim": self.cross_attention_dim,
        }

        model = UnCLIPTextProjModel(**model_kwargs)
        return model

    @property
    def dummy_decoder(self):
        torch.manual_seed(0)

        model_kwargs = {
            "sample_size": 32,
            # RGB input channels
            "in_channels": 3,
            # out channels are double the input channels because the model predicts both mean and variance
            "out_channels": 6,
            "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
            "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
            "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
            "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
            "layers_per_block": 1,
            "cross_attention_dim": self.cross_attention_dim,
            "attention_head_dim": 4,
            "resnet_time_scale_shift": "scale_shift",
            "class_embed_type": "identity",
        }

        model = UNet2DConditionModel(**model_kwargs)
        return model

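    # Super-resolution UNet: in the UnCLIP super-resolution step the low-resolution
    # decoder output is upsampled and concatenated with the noisy sample along the
    # channel dimension, hence 6 input channels and 3 (RGB) output channels.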
    @property
    def dummy_super_res_kwargs(self):
        return {
            "sample_size": 64,
            "layers_per_block": 1,
            "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
            "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
            "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
            "in_channels": 6,
            "out_channels": 3,
        }

    @property
    def dummy_super_res_first(self):
        torch.manual_seed(0)

        model = UNet2DModel(**self.dummy_super_res_kwargs)
        return model

    @property
    def dummy_super_res_last(self):
        # seeded differently to get a different unet from `self.dummy_super_res_first`
        torch.manual_seed(1)

        model = UNet2DModel(**self.dummy_super_res_kwargs)
        return model

    def get_dummy_components(self):
        prior = self.dummy_prior
        decoder = self.dummy_decoder
        text_proj = self.dummy_text_proj
        text_encoder = self.dummy_text_encoder
        tokenizer = self.dummy_tokenizer
        super_res_first = self.dummy_super_res_first
        super_res_last = self.dummy_super_res_last

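        # The prior scheduler predicts the denoised CLIP image embedding directly
        # ("sample" prediction), while the decoder and super-resolution schedulers
        # predict the added noise ("epsilon"), matching the UnCLIP / Karlo setup.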
        prior_scheduler = UnCLIPScheduler(
            variance_type="fixed_small_log",
            prediction_type="sample",
            num_train_timesteps=1000,
            clip_sample_range=5.0,
        )

        decoder_scheduler = UnCLIPScheduler(
            variance_type="learned_range",
            prediction_type="epsilon",
            num_train_timesteps=1000,
        )

        super_res_scheduler = UnCLIPScheduler(
            variance_type="fixed_small_log",
            prediction_type="epsilon",
            num_train_timesteps=1000,
        )

        components = {
            "prior": prior,
            "decoder": decoder,
            "text_proj": text_proj,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "super_res_first": super_res_first,
            "super_res_last": super_res_last,
            "prior_scheduler": prior_scheduler,
            "decoder_scheduler": decoder_scheduler,
            "super_res_scheduler": super_res_scheduler,
        }

        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "horse",
            "generator": generator,
            "prior_num_inference_steps": 2,
            "decoder_num_inference_steps": 2,
            "super_res_num_inference_steps": 2,
            "output_type": "numpy",
        }
        return inputs

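    # End-to-end fast check: runs the full dummy pipeline on CPU with 2 steps per
    # stage and compares a 3x3 corner slice of the 64x64 output against hard-coded
    # reference values, for both the dict and the tuple return types.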
    def test_unclip(self):
        device = "cpu"

        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)

        pipe.set_progress_bar_config(disable=None)

        output = pipe(**self.get_dummy_inputs(device))
        image = output.images

        image_from_tuple = pipe(
            **self.get_dummy_inputs(device),
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)

        expected_slice = np.array(
            [
                0.9997,
                0.9988,
                0.0028,
                0.9997,
                0.9984,
                0.9965,
                0.0029,
                0.9986,
                0.0025,
            ]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

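    # Fixes all latents up front so the only difference between the two runs is
    # whether the text is passed as a prompt string or as precomputed text encoder
    # outputs; the resulting images should match almost exactly.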
    def test_unclip_passed_text_embed(self):
        device = torch.device("cpu")

        class DummyScheduler:
            init_noise_sigma = 1

        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)

        prior = components["prior"]
        decoder = components["decoder"]
        super_res_first = components["super_res_first"]
        tokenizer = components["tokenizer"]
        text_encoder = components["text_encoder"]

        generator = torch.Generator(device=device).manual_seed(0)
        dtype = prior.dtype
        batch_size = 1

        shape = (batch_size, prior.config.embedding_dim)
        prior_latents = pipe.prepare_latents(
            shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
        )
        shape = (batch_size, decoder.in_channels, decoder.sample_size, decoder.sample_size)
        decoder_latents = pipe.prepare_latents(
            shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
        )

        shape = (
            batch_size,
            super_res_first.in_channels // 2,
            super_res_first.sample_size,
            super_res_first.sample_size,
        )
        super_res_latents = pipe.prepare_latents(
            shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
        )

        pipe.set_progress_bar_config(disable=None)

        prompt = "this is a prompt example"

        generator = torch.Generator(device=device).manual_seed(0)
        output = pipe(
            [prompt],
            generator=generator,
            prior_num_inference_steps=2,
            decoder_num_inference_steps=2,
            super_res_num_inference_steps=2,
            prior_latents=prior_latents,
            decoder_latents=decoder_latents,
            super_res_latents=super_res_latents,
            output_type="np",
        )
        image = output.images

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_model_output = text_encoder(text_inputs.input_ids)
        text_attention_mask = text_inputs.attention_mask

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_text = pipe(
            generator=generator,
            prior_num_inference_steps=2,
            decoder_num_inference_steps=2,
            super_res_num_inference_steps=2,
            prior_latents=prior_latents,
            decoder_latents=decoder_latents,
            super_res_latents=super_res_latents,
            text_model_output=text_model_output,
            text_attention_mask=text_attention_mask,
            output_type="np",
        )[0]

        # make sure passing text embeddings manually gives the same result
        assert np.abs(image - image_from_text).max() < 1e-4

    # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
    # because UnCLIP GPU non-determinism requires a looser check.
    @skip_mps
    def test_attention_slicing_forward_pass(self):
        test_max_difference = torch_device == "cpu"

        self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)

    # Overriding PipelineTesterMixin::test_inference_batch_single_identical
    # because UnCLIP non-determinism requires a looser check.
    @skip_mps
    def test_inference_batch_single_identical(self):
        test_max_difference = torch_device == "cpu"
        relax_max_difference = True
        additional_params_copy_to_batched_inputs = [
            "prior_num_inference_steps",
            "decoder_num_inference_steps",
            "super_res_num_inference_steps",
        ]

        self._test_inference_batch_single_identical(
            test_max_difference=test_max_difference,
            relax_max_difference=relax_max_difference,
            additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs,
        )

    def test_inference_batch_consistent(self):
        additional_params_copy_to_batched_inputs = [
            "prior_num_inference_steps",
            "decoder_num_inference_steps",
            "super_res_num_inference_steps",
        ]

        if torch_device == "mps":
            # TODO: MPS errors with larger batch sizes
            batch_sizes = [2, 3]
            self._test_inference_batch_consistent(
                batch_sizes=batch_sizes,
                additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs,
            )
        else:
            self._test_inference_batch_consistent(
                additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs
            )

    @skip_mps
    def test_dict_tuple_outputs_equivalent(self):
        return super().test_dict_tuple_outputs_equivalent()

    @skip_mps
    def test_save_load_local(self):
        return super().test_save_load_local()

    @skip_mps
    def test_save_load_optional_components(self):
        return super().test_save_load_optional_components()

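# Nightly CPU integration test: downloads the public kakaobrain/karlo-v1-alpha
# checkpoint and compares a full fp32 CPU generation against a stored reference
# image from the hf-internal-testing dataset.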
@nightly
class UnCLIPPipelineCPUIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_unclip_karlo_cpu_fp32(self):
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/unclip/karlo_v1_alpha_horse_cpu.npy"
        )

        pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
        pipeline.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        output = pipeline(
            "horse",
            num_images_per_prompt=1,
            generator=generator,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (256, 256, 3)
        assert np.abs(expected_image - image).max() < 1e-1

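# Slow GPU integration tests: load the kakaobrain/karlo-v1-alpha checkpoint in
# fp16 and compare generations against stored reference images, plus a memory
# check for sequential CPU offloading.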
@slow
@require_torch_gpu
class UnCLIPPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_unclip_karlo(self):
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/unclip/karlo_v1_alpha_horse_fp16.npy"
        )

        pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
        pipeline = pipeline.to(torch_device)
        pipeline.set_progress_bar_config(disable=None)

        generator = torch.Generator(device="cpu").manual_seed(0)
        output = pipeline(
            "horse",
            generator=generator,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (256, 256, 3)

        assert_mean_pixel_difference(image, expected_image)

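    # With sequential CPU offloading, submodules are moved to the GPU one at a time
    # during the forward pass, so peak VRAM usage stays bounded; the assertion below
    # guards that bound at roughly 7 GB.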
    def test_unclip_pipeline_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()
        pipe.enable_sequential_cpu_offload()

        _ = pipe(
            "horse",
            num_images_per_prompt=1,
            prior_num_inference_steps=2,
            decoder_num_inference_steps=2,
            super_res_num_inference_steps=2,
            output_type="np",
        )

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 7 GB is allocated
        assert mem_bytes < 7 * 10**9