# Path: blob/main/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py
# (source page metadata: 1451 views)
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
from diffusers.utils import is_flax_available, load_image, slow
from diffusers.utils.testing_utils import require_flax


if is_flax_available():
    import jax
    import jax.numpy as jnp
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard


@slow
@require_flax
class FlaxStableDiffusionControlNetPipelineIntegrationTests(unittest.TestCase):
    """End-to-end checks of the Flax ControlNet pipeline against reference pixel values.

    Each test loads a pretrained ControlNet (from PyTorch weights, in bfloat16) on top of
    the Stable Diffusion v1.5 base model. It generates one sample per JAX device and
    compares a 3x3 patch of the last channel of the first image against hard-coded
    reference values.
    """

    # Base Stable Diffusion checkpoint shared by every test.
    BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
    # Maximum absolute difference allowed between the output and reference slices.
    TOLERANCE = 1e-2

    def tearDown(self):
        # clean up the VRAM after each test: collecting garbage lets arrays from the
        # previous test be freed before the next one loads its models.
        super().tearDown()
        gc.collect()

    def _check_controlnet_pipeline(self, controlnet_id, prompt, image_url, expected_slice):
        """Run the ControlNet pipeline once and compare a pixel slice with a reference.

        Args:
            controlnet_id: Hub repo id of the ControlNet checkpoint (loaded with ``from_pt=True``).
            prompt: Text prompt; it is repeated once per JAX device.
            image_url: URL of the conditioning image; it is repeated once per JAX device.
            expected_slice: Nine reference values for ``images[0, 253:256, 253:256, -1]``
                after flattening.
        """
        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
            controlnet_id, from_pt=True, dtype=jnp.bfloat16
        )
        pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
            self.BASE_MODEL_ID, controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
        )
        # The pipeline's parameter tree does not include the separately loaded ControlNet weights.
        params["controlnet"] = controlnet_params

        # One sample per device so the batch splits evenly across devices.
        num_samples = jax.device_count()
        prompt_ids = pipe.prepare_text_inputs([prompt] * num_samples)

        control_image = load_image(image_url)
        processed_image = pipe.prepare_image_inputs([control_image] * num_samples)

        # A fixed seed, split into one PRNG key per device, keeps the output reproducible.
        rng = jax.random.PRNGKey(0)
        rng = jax.random.split(rng, jax.device_count())

        # Copy the parameters to every device and split the inputs along a leading device axis.
        p_params = replicate(params)
        prompt_ids = shard(prompt_ids)
        processed_image = shard(processed_image)

        images = pipe(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=p_params,
            prng_seed=rng,
            num_inference_steps=50,
            jit=True,
        ).images
        assert images.shape == (jax.device_count(), 1, 768, 512, 3)

        # Merge the two leading (device, per-device) axes into one batch axis.
        images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
        image_slice = images[0, 253:256, 253:256, -1]

        output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
        # Printed so reference values are easy to refresh when the expected output changes.
        print(f"output_slice: {output_slice}")
        assert jnp.abs(output_slice - jnp.array(expected_slice)).max() < self.TOLERANCE

    def test_canny(self):
        self._check_controlnet_pipeline(
            controlnet_id="lllyasviel/sd-controlnet-canny",
            prompt="bird",
            image_url=(
                "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
            ),
            expected_slice=[
                0.167969, 0.116699, 0.081543,
                0.154297, 0.132812, 0.108887,
                0.169922, 0.169922, 0.205078,
            ],
        )

    def test_pose(self):
        self._check_controlnet_pipeline(
            controlnet_id="lllyasviel/sd-controlnet-openpose",
            prompt="Chef in the kitchen",
            image_url=(
                "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png"
            ),
            # Flat (9,) reference, matching the flattened output slice.
            expected_slice=[
                0.271484, 0.261719, 0.275391,
                0.277344, 0.279297, 0.291016,
                0.294922, 0.302734, 0.302734,
            ],
        )