Path: blob/main/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    StableDiffusionSAGPipeline,
    UNet2DConditionModel,
)
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu

from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class StableDiffusionSAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = StableDiffusionSAGPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    test_cpu_offload = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": ".",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 1.0,
            "sag_scale": 1.0,
            "output_type": "numpy",
        }
        return inputs


@slow
@require_torch_gpu
class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_diffusion_1(self):
        sag_pipe = StableDiffusionSAGPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
        sag_pipe = sag_pipe.to(torch_device)
        sag_pipe.set_progress_bar_config(disable=None)

        prompt = "."
        generator = torch.manual_seed(0)
        output = sag_pipe(
            [prompt], generator=generator, guidance_scale=7.5, sag_scale=1.0, num_inference_steps=20, output_type="np"
        )

        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.1568, 0.1738, 0.1695, 0.1693, 0.1507, 0.1705, 0.1547, 0.1751, 0.1949])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2

    def test_stable_diffusion_2(self):
        sag_pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
        sag_pipe = sag_pipe.to(torch_device)
        sag_pipe.set_progress_bar_config(disable=None)

        prompt = "."
        generator = torch.manual_seed(0)
        output = sag_pipe(
            [prompt], generator=generator, guidance_scale=7.5, sag_scale=1.0, num_inference_steps=20, output_type="np"
        )

        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.3459, 0.2876, 0.2537, 0.3002, 0.2671, 0.2160, 0.3026, 0.2262, 0.2371])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2

    def test_stable_diffusion_2_non_square(self):
        sag_pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
        sag_pipe = sag_pipe.to(torch_device)
        sag_pipe.set_progress_bar_config(disable=None)

        prompt = "."
        generator = torch.manual_seed(0)
        output = sag_pipe(
            [prompt],
            width=768,
            height=512,
            generator=generator,
            guidance_scale=7.5,
            sag_scale=1.0,
            num_inference_steps=20,
            output_type="np",
        )

        image = output.images

        assert image.shape == (1, 512, 768, 3)
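

if __name__ == "__main__":
    # Standalone sketch, not part of the test suite: the same SAG invocation the
    # integration tests above exercise, runnable directly for a quick manual check.
    # Assumes network access to download "CompVis/stable-diffusion-v1-4" and,
    # realistically, a CUDA device; the prompt below is illustrative (the tests
    # themselves use "."). `sag_scale` sets the strength of self-attention
    # guidance applied on top of classifier-free guidance (`guidance_scale`).
    pipe = StableDiffusionSAGPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    pipe = pipe.to(torch_device)
    image = pipe(
        "a photo of an astronaut riding a horse",
        generator=torch.manual_seed(0),
        guidance_scale=7.5,
        sag_scale=1.0,
        num_inference_steps=20,
    ).images[0]
    image.save("sag_sample.png")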