# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DDPMScheduler,
    StableDiffusionUpscalePipeline,
    UNet2DConditionModel,
)
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu


# disable TF32 matmuls so the expected image slices stay reproducible
torch.backends.cuda.matmul.allow_tf32 = False


class StableDiffusionUpscalePipelineFastTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        # random float tensor with values in [0, 1)
        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
        return image

    @property
    def dummy_cond_unet_upscale(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 32, 64),
            layers_per_block=2,
            sample_size=32,
            # 4 latent channels + 3 channels for the concatenated low-res image
            in_channels=7,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            # SD2-specific config below
            attention_head_dim=8,
            use_linear_projection=True,
            only_cross_attention=(True, True, False),
            num_class_embeds=100,
        )
        return model

    @property
    def dummy_vae(self):
        torch.manual_seed(0)
        model = AutoencoderKL(
            block_out_channels=[32, 32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        return model

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            # SD2-specific config below
            hidden_act="gelu",
            projection_dim=512,
        )
        return CLIPTextModel(config)
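
    # Each fast test below assembles the pipeline from the dummy components
    # above in the same way. A minimal sketch of a shared helper that could
    # factor this out (`get_dummy_pipeline` is a hypothetical name, not part
    # of this file):
    #
    #     def get_dummy_pipeline(self):
    #         return StableDiffusionUpscalePipeline(
    #             unet=self.dummy_cond_unet_upscale,
    #             low_res_scheduler=DDPMScheduler(),
    #             scheduler=DDIMScheduler(prediction_type="v_prediction"),
    #             vae=self.dummy_vae,
    #             text_encoder=self.dummy_text_encoder,
    #             tokenizer=CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip"),
    #             max_noise_level=350,
    #         )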

    def test_stable_diffusion_upscale(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet_upscale
        low_res_scheduler = DDPMScheduler()
        scheduler = DDIMScheduler(prediction_type="v_prediction")
        vae = self.dummy_vae
        text_encoder = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))

        # assemble the pipeline from the dummy components
        sd_pipe = StableDiffusionUpscalePipeline(
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            max_noise_level=350,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            [prompt],
            image=low_res_image,
            generator=generator,
            guidance_scale=6.0,
            noise_level=20,
            num_inference_steps=2,
            output_type="np",
        )

        image = output.images

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = sd_pipe(
            [prompt],
            image=low_res_image,
            generator=generator,
            guidance_scale=6.0,
            noise_level=20,
            num_inference_steps=2,
            output_type="np",
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        # the x4 upscaler quadruples the input resolution: 64 -> 256
        expected_height_width = low_res_image.size[0] * 4
        assert image.shape == (1, expected_height_width, expected_height_width, 3)
        expected_slice = np.array([0.2562, 0.3606, 0.4204, 0.4469, 0.4822, 0.4647, 0.5315, 0.5748, 0.5606])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_upscale_batch(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet_upscale
        low_res_scheduler = DDPMScheduler()
        scheduler = DDIMScheduler(prediction_type="v_prediction")
        vae = self.dummy_vae
        text_encoder = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))

        # assemble the pipeline from the dummy components
        sd_pipe = StableDiffusionUpscalePipeline(
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            max_noise_level=350,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        # batching via a list of prompts and a list of images
        output = sd_pipe(
            2 * [prompt],
            image=2 * [low_res_image],
            guidance_scale=6.0,
            noise_level=20,
            num_inference_steps=2,
            output_type="np",
        )
        image = output.images
        assert image.shape[0] == 2

        # batching via num_images_per_prompt
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            [prompt],
            image=low_res_image,
            generator=generator,
            num_images_per_prompt=2,
            guidance_scale=6.0,
            noise_level=20,
            num_inference_steps=2,
            output_type="np",
        )
        image = output.images
        assert image.shape[0] == 2
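
    # The batched call above shares one RNG across both images. A hedged
    # sketch of per-image reproducibility, assuming the installed diffusers
    # version accepts a list of generators (one per image in the batch):
    #
    #     generators = [torch.Generator(device=device).manual_seed(i) for i in range(2)]
    #     output = sd_pipe(
    #         2 * [prompt],
    #         image=2 * [low_res_image],
    #         generator=generators,
    #         guidance_scale=6.0,
    #         noise_level=20,
    #         num_inference_steps=2,
    #         output_type="np",
    #     )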

    @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
    def test_stable_diffusion_upscale_fp16(self):
        """Test that stable diffusion upscale works with fp16"""
        unet = self.dummy_cond_unet_upscale
        low_res_scheduler = DDPMScheduler()
        scheduler = DDIMScheduler(prediction_type="v_prediction")
        vae = self.dummy_vae
        text_encoder = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))

        # put models in fp16, except the vae, which overflows in fp16
        unet = unet.half()
        text_encoder = text_encoder.half()

        # assemble the pipeline from the dummy components
        sd_pipe = StableDiffusionUpscalePipeline(
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            max_noise_level=350,
        )
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = sd_pipe(
            [prompt],
            image=low_res_image,
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        ).images

        expected_height_width = low_res_image.size[0] * 4
        assert image.shape == (1, expected_height_width, expected_height_width, 3)


@slow
@require_torch_gpu
class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_diffusion_upscale_pipeline(self):
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-upscale/low_res_cat.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale"
            "/upsampled_cat.npy"
        )

        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "a cat sitting on a park bench"

        generator = torch.manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        assert np.abs(expected_image - image).max() < 1e-3

    def test_stable_diffusion_upscale_pipeline_fp16(self):
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-upscale/low_res_cat.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale"
            "/upsampled_cat_fp16.npy"
        )

        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        pipe = StableDiffusionUpscalePipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "a cat sitting on a park bench"

        generator = torch.manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        # fp16 rounding drifts, so the tolerance is much looser than in fp32
        assert np.abs(expected_image - image).max() < 5e-1

    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-upscale/low_res_cat.png"
        )

        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        pipe = StableDiffusionUpscalePipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        prompt = "a cat sitting on a park bench"

        generator = torch.manual_seed(0)
        _ = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            num_inference_steps=5,
            output_type="np",
        )

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 2.9 GB is allocated
        assert mem_bytes < 2.9 * 10**9
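

# How to run (a sketch; assumes pytest and a diffusers dev install, invoked
# from the repository root):
#
#   pytest tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py -k "FastTests" -v
#
# The integration tests above are decorated with @slow, so they are skipped
# unless the RUN_SLOW=1 environment variable is set, and @require_torch_gpu
# additionally skips them on machines without a CUDA device.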