GitHub Repository: shivamshrirao/diffusers
Path: blob/main/tests/pipelines/altdiffusion/test_alt_diffusion.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, XLMRobertaTokenizer

from diffusers import AltDiffusionPipeline, AutoencoderKL, DDIMScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
    RobertaSeriesConfig,
    RobertaSeriesModelWithTransformation,
)
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu

from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


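# TF32 matmuls trade precision for speed on Ampere+ GPUs; disable them so generated
# images stay numerically comparable to the hard-coded reference slices below.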
torch.backends.cuda.matmul.allow_tf32 = False


class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = AltDiffusionPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

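    # Assemble a miniature UNet/scheduler/VAE/text-encoder stack so the fast tests run in seconds on CPU.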
    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )

        # TODO: address the non-deterministic text encoder (fails for save-load tests)
        # torch.manual_seed(0)
        # text_encoder_config = RobertaSeriesConfig(
        #     hidden_size=32,
        #     project_dim=32,
        #     intermediate_size=37,
        #     layer_norm_eps=1e-05,
        #     num_attention_heads=4,
        #     num_hidden_layers=5,
        #     vocab_size=5002,
        # )
        # text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)

        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            projection_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=5002,
        )
        text_encoder = CLIPTextModel(text_encoder_config)

        tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
        tokenizer.model_max_length = 77

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

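    # Shared prompt, seeded generator, and two-step schedule used by the fast tests below.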
    def get_dummy_inputs(self, device, seed=0):
        # torch.Generator does not support the "mps" device, so fall back to the global CPU-seeded generator there
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }
        return inputs

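    # Each fast test checks the output shape and compares a 3x3 corner patch of the
    # last channel against hard-coded reference values, with a 1e-2 tolerance.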
    def test_alt_diffusion_ddim(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        components = self.get_dummy_components()
        torch.manual_seed(0)
        text_encoder_config = RobertaSeriesConfig(
            hidden_size=32,
            project_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            vocab_size=5002,
        )
        # TODO: remove after fixing the non-deterministic text encoder
        text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)
        components["text_encoder"] = text_encoder

        alt_pipe = AltDiffusionPipeline(**components)
        alt_pipe = alt_pipe.to(device)
        alt_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["prompt"] = "A photo of an astronaut"
        output = alt_pipe(**inputs)
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array(
            [0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_alt_diffusion_pndm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        components = self.get_dummy_components()
        components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
        torch.manual_seed(0)
        text_encoder_config = RobertaSeriesConfig(
            hidden_size=32,
            project_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            vocab_size=5002,
        )
        # TODO: remove after fixing the non-deterministic text encoder
        text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)
        components["text_encoder"] = text_encoder
        alt_pipe = AltDiffusionPipeline(**components)
        alt_pipe = alt_pipe.to(device)
        alt_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = alt_pipe(**inputs)
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array(
            [0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


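# The integration tests below download the full BAAI/AltDiffusion checkpoint, so they
# are gated behind the slow-test flag and require a CUDA GPU.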
198
@slow
199
@require_torch_gpu
200
class AltDiffusionPipelineIntegrationTests(unittest.TestCase):
201
def tearDown(self):
202
# clean up the VRAM after each test
203
super().tearDown()
204
gc.collect()
205
torch.cuda.empty_cache()
206
    def test_alt_diffusion(self):
        # make sure that the default PNDM scheduler skips the PRK steps
        alt_pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", safety_checker=None)
        alt_pipe = alt_pipe.to(torch_device)
        alt_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np")

        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.1010, 0.0800, 0.0794, 0.0885, 0.0843, 0.0762, 0.0769, 0.0729, 0.0586])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_alt_diffusion_fast_ddim(self):
        scheduler = DDIMScheduler.from_pretrained("BAAI/AltDiffusion", subfolder="scheduler")

        alt_pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", scheduler=scheduler, safety_checker=None)
        alt_pipe = alt_pipe.to(torch_device)
        alt_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)

        output = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.4019, 0.4052, 0.3810, 0.4119, 0.3916, 0.3982, 0.4651, 0.4195, 0.5323])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
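
# Typical invocations, assuming a diffusers development checkout with test dependencies installed:
#   fast CPU tests:       pytest tests/pipelines/altdiffusion/test_alt_diffusion.py -k "FastTests"
#   slow GPU integration: RUN_SLOW=1 pytest tests/pipelines/altdiffusion/test_alt_diffusion.py -k "IntegrationTests"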