Path: blob/main/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py
1448 views
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import tempfile
import unittest

import numpy as np
import torch

from diffusers import VersatileDiffusionDualGuidedPipeline
from diffusers.utils.testing_utils import load_image, nightly, require_torch_gpu, torch_device


# Disable TF32 matmuls on CUDA so outputs are stable enough to compare against
# the hard-coded reference slice in test_inference_dual_guided.
torch.backends.cuda.matmul.allow_tf32 = False

# Image passed as the `image` conditioning input in every test below.
BENZ_IMAGE_URL = (
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/versatile_diffusion/benz.jpg"
)


@nightly
@require_torch_gpu
class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
    """Nightly GPU integration tests for ``VersatileDiffusionDualGuidedPipeline``."""

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _generate(pipe, prompt, image, num_inference_steps):
        """Run ``pipe`` once with a generator re-seeded to 0 and return ``.images``.

        Args:
            pipe: the dual-guided pipeline under test.
            prompt: text conditioning.
            image: image conditioning (as returned by ``load_image``).
            num_inference_steps: number of denoising steps.

        Returns:
            The pipeline's ``images`` output, requested as a numpy array.
        """
        # torch.manual_seed returns the (re-seeded) global CPU generator, so every
        # call starts from the same random state.
        generator = torch.manual_seed(0)
        return pipe(
            prompt=prompt,
            image=image,
            text_to_image_strength=0.75,
            generator=generator,
            guidance_scale=7.5,
            num_inference_steps=num_inference_steps,
            # "np" is the canonical value; "numpy" is only accepted as a deprecated alias.
            output_type="np",
        ).images

    def test_remove_unused_weights_save_load(self):
        """A pipeline saved after ``remove_unused_weights`` reloads with an identical forward pass."""
        pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion")
        # remove text_unet
        pipe.remove_unused_weights()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        image_input = load_image(BENZ_IMAGE_URL)
        image = self._generate(pipe, "first prompt", image_input, num_inference_steps=2)

        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe.save_pretrained(tmpdirname)
            # Load while the directory still exists; it is deleted when the block exits.
            pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(tmpdirname)

        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        new_image = self._generate(pipe, "first prompt", image_input, num_inference_steps=2)

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"

    def test_inference_dual_guided(self):
        """Full 50-step dual-guided generation matches a stored reference slice."""
        pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion")
        pipe.remove_unused_weights()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        image = self._generate(pipe, "cyberpunk 2077", load_image(BENZ_IMAGE_URL), num_inference_steps=50)

        # Check the shape before slicing so a wrong output size fails with a clear message.
        assert image.shape == (1, 512, 512, 3), f"unexpected output shape {image.shape}"

        # 3x3 patch near the image centre, last channel only.
        image_slice = image[0, 253:256, 253:256, -1]
        expected_slice = np.array([0.0787, 0.0849, 0.0826, 0.0812, 0.0807, 0.0795, 0.0818, 0.0798, 0.0779])

        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
        assert max_diff < 1e-2, f"output slice differs from reference by {max_diff}"