Path: blob/main/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DDPMScheduler,
    StableDiffusionUpscalePipeline,
    UNet2DConditionModel,
)
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu


torch.backends.cuda.matmul.allow_tf32 = False


class StableDiffusionUpscalePipelineFastTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
        return image

    @property
    def dummy_cond_unet_upscale(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=7,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            # SD2-specific config below
            attention_head_dim=8,
            use_linear_projection=True,
            only_cross_attention=(True, True, False),
            num_class_embeds=100,
        )
        return model

    @property
    def dummy_vae(self):
        torch.manual_seed(0)
        model = AutoencoderKL(
            block_out_channels=[32, 32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        return model

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            # SD2-specific config below
            hidden_act="gelu",
            projection_dim=512,
        )
        return CLIPTextModel(config)

    def test_stable_diffusion_upscale(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet_upscale
        low_res_scheduler = DDPMScheduler()
        scheduler = DDIMScheduler(prediction_type="v_prediction")
        vae = self.dummy_vae
        text_encoder = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))

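        # note: StableDiffusionUpscalePipeline perturbs the low-res RGB image with
        # `low_res_scheduler` according to `noise_level`, concatenates it channel-wise
        # with the 4-channel latent (hence the dummy UNet's `in_channels=7`), and feeds
        # the noise level to the UNet as a class label (hence `num_class_embeds`)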
        # assemble the upscale pipeline from the dummy components
        sd_pipe = StableDiffusionUpscalePipeline(
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            max_noise_level=350,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            [prompt],
            image=low_res_image,
            generator=generator,
            guidance_scale=6.0,
            noise_level=20,
            num_inference_steps=2,
            output_type="np",
        )

        image = output.images

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = sd_pipe(
            [prompt],
            image=low_res_image,
            generator=generator,
            guidance_scale=6.0,
            noise_level=20,
            num_inference_steps=2,
            output_type="np",
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        # the x4 upscaler produces output at 4x the low-res input resolution
        expected_height_width = low_res_image.size[0] * 4
        assert image.shape == (1, expected_height_width, expected_height_width, 3)
        expected_slice = np.array([0.2562, 0.3606, 0.4204, 0.4469, 0.4822, 0.4647, 0.5315, 0.5748, 0.5606])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_upscale_batch(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet_upscale
        low_res_scheduler = DDPMScheduler()
        scheduler = DDIMScheduler(prediction_type="v_prediction")
        vae = self.dummy_vae
        text_encoder = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))

        # assemble the upscale pipeline from the dummy components
        sd_pipe = StableDiffusionUpscalePipeline(
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            max_noise_level=350,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        output = sd_pipe(
            2 * [prompt],
            image=2 * [low_res_image],
            guidance_scale=6.0,
            noise_level=20,
            num_inference_steps=2,
            output_type="np",
        )
        image = output.images
        assert image.shape[0] == 2

        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            [prompt],
            image=low_res_image,
            generator=generator,
            num_images_per_prompt=2,
            guidance_scale=6.0,
            noise_level=20,
            num_inference_steps=2,
            output_type="np",
        )
        image = output.images
        assert image.shape[0] == 2

    @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
    def test_stable_diffusion_upscale_fp16(self):
        """Test that stable diffusion upscale works with fp16"""
        unet = self.dummy_cond_unet_upscale
        low_res_scheduler = DDPMScheduler()
        scheduler = DDIMScheduler(prediction_type="v_prediction")
        vae = self.dummy_vae
        text_encoder = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))

        # put models in fp16, except vae as it overflows in fp16
        unet = unet.half()
        text_encoder = text_encoder.half()

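        # note: the fp16 run below is only shape-checked, since half-precision numerics
        # drift from the fp32 reference slices asserted in the tests above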
        # assemble the upscale pipeline from the dummy components
        sd_pipe = StableDiffusionUpscalePipeline(
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            max_noise_level=350,
        )
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = sd_pipe(
            [prompt],
            image=low_res_image,
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        ).images

        expected_height_width = low_res_image.size[0] * 4
        assert image.shape == (1, expected_height_width, expected_height_width, 3)


@slow
@require_torch_gpu
class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_diffusion_upscale_pipeline(self):
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-upscale/low_res_cat.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale"
            "/upsampled_cat.npy"
        )

        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "a cat sitting on a park bench"

        generator = torch.manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        assert np.abs(expected_image - image).max() < 1e-3

    def test_stable_diffusion_upscale_pipeline_fp16(self):
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-upscale/low_res_cat.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale"
            "/upsampled_cat_fp16.npy"
        )

        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        pipe = StableDiffusionUpscalePipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "a cat sitting on a park bench"

        generator = torch.manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        # fp16 output drifts further from the fp32 reference, hence the looser tolerance
        assert np.abs(expected_image - image).max() < 5e-1

    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-upscale/low_res_cat.png"
        )

        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        pipe = StableDiffusionUpscalePipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        prompt = "a cat sitting on a park bench"

        generator = torch.manual_seed(0)
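        # with sequential CPU offload, each submodule is moved to the GPU only for its
        # own forward pass and back to CPU afterwards, keeping peak VRAM low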
        _ = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            num_inference_steps=5,
            output_type="np",
        )

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 2.9 GB is allocated
        assert mem_bytes < 2.9 * 10**9
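

if __name__ == "__main__":
    # optional convenience: allow running this file directly
    # (`python test_stable_diffusion_upscale.py`) in addition to pytest;
    # unittest.main() discovers and runs the test classes above
    unittest.main()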