Path: blob/main/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py
1451 views
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch

from diffusers import StableDiffusionKDiffusionPipeline
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu


# Turn off TF32 matmuls on CUDA; presumably so the reference slices below
# stay comparable across GPU generations -- confirm against project policy.
torch.backends.cuda.matmul.allow_tf32 = False


@slow
@require_torch_gpu
class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
    """Integration tests for ``StableDiffusionKDiffusionPipeline``.

    Each test loads a pretrained checkpoint, samples one image with the
    ``sample_euler`` scheduler from a fixed seed, and compares a 3x3 corner
    slice of the result against stored reference values.
    """

    def tearDown(self):
        """Release Python garbage and cached CUDA memory after each test."""
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def _sample_with_euler(self, model_id):
        """Run the shared sampling recipe for ``model_id`` and return the images.

        The pipeline is loaded with ``from_pretrained``, moved to
        ``torch_device``, switched to the ``sample_euler`` scheduler and
        called with a fixed prompt, seed 0, guidance scale 9.0 and 20
        inference steps.

        Args:
            model_id: Checkpoint identifier passed to ``from_pretrained``.

        Returns:
            The ``images`` attribute of the pipeline output, requested with
            ``output_type="np"``.
        """
        sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained(model_id)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        sd_pipe.set_scheduler("sample_euler")

        prompt = "A painting of a squirrel eating a burger"
        # Seed immediately before the call so both tests start from the same
        # generator state.
        generator = torch.manual_seed(0)
        output = sd_pipe(
            [prompt],
            generator=generator,
            guidance_scale=9.0,
            num_inference_steps=20,
            output_type="np",
        )

        return output.images

    def test_stable_diffusion_1(self):
        """Check the ``CompVis/stable-diffusion-v1-4`` output against its reference slice."""
        image = self._sample_with_euler("CompVis/stable-diffusion-v1-4")

        # Last 3 rows x last 3 columns of the final channel of the first image.
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.0447, 0.0492, 0.0468, 0.0408, 0.0383, 0.0408, 0.0354, 0.0380, 0.0339])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_2(self):
        """Check the ``stabilityai/stable-diffusion-2-1-base`` output against its reference slice."""
        image = self._sample_with_euler("stabilityai/stable-diffusion-2-1-base")

        # Last 3 rows x last 3 columns of the final channel of the first image.
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.1237, 0.1320, 0.1438, 0.1359, 0.1390, 0.1132, 0.1277, 0.1175, 0.1112])

        # NOTE(review): this tolerance (5e-1) is larger than every expected
        # value above, so the comparison is very permissive compared with the
        # 1e-2 used in test_stable_diffusion_1 -- confirm it is intentional.
        assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-1