# NOTE: web-viewer chrome removed during cleanup. Original file:
#   shivamshrirao/diffusers — tests/pipelines/vq_diffusion/test_vq_diffusion.py
1
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
from diffusers.utils import load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu


# Disable TF32 matmul kernels so float results are bit-reproducible across runs,
# keeping the hard-coded expected slices below stable.
torch.backends.cuda.matmul.allow_tf32 = False
class VQDiffusionPipelineFastTests(unittest.TestCase):
    """Fast CPU smoke tests for ``VQDiffusionPipeline`` built from tiny dummy components.

    Each ``dummy_*`` property seeds torch before constructing its module so the
    pipeline output is deterministic and can be compared against hard-coded
    expected slices.
    """

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def num_embed(self):
        # Size of the VQ codebook, shared by the VQModel, scheduler, and transformer.
        return 12

    @property
    def num_embeds_ada_norm(self):
        # Number of diffusion timestep embeddings for AdaLayerNorm in the transformer.
        return 12

    @property
    def text_embedder_hidden_size(self):
        # Hidden size of the CLIP text encoder; must match the transformer's
        # cross_attention_dim below.
        return 32

    @property
    def dummy_vqvae(self):
        """A tiny seeded VQModel whose codebook size matches ``num_embed``."""
        torch.manual_seed(0)
        model = VQModel(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=3,
            num_vq_embeddings=self.num_embed,
            vq_embed_dim=3,
        )
        return model

    @property
    def dummy_tokenizer(self):
        """A tiny CLIP tokenizer pulled from the hf-internal-testing hub repo."""
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        return tokenizer

    @property
    def dummy_text_encoder(self):
        """A tiny seeded CLIP text encoder with ``text_embedder_hidden_size`` outputs."""
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=self.text_embedder_hidden_size,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModel(config)

    @property
    def dummy_transformer(self):
        """A tiny seeded Transformer2DModel operating on a 12x12 discrete latent grid."""
        torch.manual_seed(0)

        height = 12
        width = 12

        model_kwargs = {
            "attention_bias": True,
            "cross_attention_dim": 32,
            "attention_head_dim": height * width,
            "num_attention_heads": 1,
            "num_vector_embeds": self.num_embed,
            "num_embeds_ada_norm": self.num_embeds_ada_norm,
            "norm_num_groups": 32,
            "sample_size": width,
            "activation_fn": "geglu-approximate",
        }

        model = Transformer2DModel(**model_kwargs)
        return model

    def test_vq_diffusion(self):
        """End-to-end CPU run without learned classifier-free sampling embeddings."""
        device = "cpu"

        vqvae = self.dummy_vqvae
        text_encoder = self.dummy_text_encoder
        tokenizer = self.dummy_tokenizer
        transformer = self.dummy_transformer
        scheduler = VQDiffusionScheduler(self.num_embed)
        # learnable=False: the pipeline uses the (non-learned) uniform embeddings path.
        learned_classifier_free_sampling_embeddings = LearnedClassifierFreeSamplingEmbeddings(learnable=False)

        pipe = VQDiffusionPipeline(
            vqvae=vqvae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
        )
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        prompt = "teddy bear playing in the pool"

        generator = torch.Generator(device=device).manual_seed(0)
        output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np")
        image = output.images

        # Re-seed and run again with return_dict=False; both paths must agree.
        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = pipe(
            [prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 24, 24, 3)

        expected_slice = np.array([0.6583, 0.6410, 0.5325, 0.5635, 0.5563, 0.4234, 0.6008, 0.5491, 0.4880])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_vq_diffusion_classifier_free_sampling(self):
        """End-to-end CPU run WITH learned classifier-free sampling embeddings."""
        device = "cpu"

        vqvae = self.dummy_vqvae
        text_encoder = self.dummy_text_encoder
        tokenizer = self.dummy_tokenizer
        transformer = self.dummy_transformer
        scheduler = VQDiffusionScheduler(self.num_embed)
        # learnable=True requires the embedding shape: (model_max_length, hidden_size).
        learned_classifier_free_sampling_embeddings = LearnedClassifierFreeSamplingEmbeddings(
            learnable=True, hidden_size=self.text_embedder_hidden_size, length=tokenizer.model_max_length
        )

        pipe = VQDiffusionPipeline(
            vqvae=vqvae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
        )
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        prompt = "teddy bear playing in the pool"

        generator = torch.Generator(device=device).manual_seed(0)
        output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np")
        image = output.images

        # Re-seed and run again with return_dict=False; both paths must agree.
        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = pipe(
            [prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 24, 24, 3)

        expected_slice = np.array([0.6647, 0.6531, 0.5303, 0.5891, 0.5726, 0.4439, 0.6304, 0.5564, 0.4912])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@slow
@require_torch_gpu
class VQDiffusionPipelineIntegrationTests(unittest.TestCase):
    """Slow GPU integration test against the pretrained microsoft/vq-diffusion-ithq checkpoint."""

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_vq_diffusion_classifier_free_sampling(self):
        """Compare a full pretrained-pipeline run against a reference image from the hub."""
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/vq_diffusion/teddy_bear_pool_classifier_free_sampling.npy"
        )

        pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
        pipeline = pipeline.to(torch_device)
        pipeline.set_progress_bar_config(disable=None)

        # requires GPU generator for gumbel softmax
        # don't use GPU generator in tests though
        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipeline(
            "teddy bear playing in the pool",
            num_images_per_prompt=1,
            generator=generator,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (256, 256, 3)
        assert np.abs(expected_image - image).max() < 1e-2