Path: blob/main/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
1451 views
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from diffusers import LMSDiscreteScheduler, OnnxStableDiffusionInpaintPipeline
from diffusers.utils.testing_utils import (
    is_onnx_available,
    load_image,
    nightly,
    require_onnxruntime,
    require_torch_gpu,
)

from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin


# onnxruntime is optional; only import it when diffusers reports ONNX support.
if is_onnx_available():
    import onnxruntime as ort


class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase):
    """Placeholder for fast (non-nightly) tests of the ONNX inpaint pipeline."""

    # FIXME: add fast tests
    pass


@nightly
@require_onnxruntime
@require_torch_gpu
class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
    """End-to-end inpainting checks against the ``runwayml/stable-diffusion-inpainting`` ONNX export.

    Each test inpaints the same image/mask pair with a fixed prompt and seed.
    It then compares a small patch of the output with reference values.
    """

    @property
    def gpu_provider(self):
        """ONNX Runtime provider spec: the CUDA provider plus its provider options."""
        return (
            "CUDAExecutionProvider",
            {
                "gpu_mem_limit": "15000000000",  # 15GB
                "arena_extend_strategy": "kSameAsRequested",
            },
        )

    @property
    def gpu_options(self):
        """Build a fresh ``ort.SessionOptions`` with memory-pattern optimization disabled."""
        options = ort.SessionOptions()
        # NOTE(review): presumably disabled to keep GPU memory usage predictable — confirm.
        options.enable_mem_pattern = False
        return options

    def _load_inpaint_inputs(self):
        """Download the init image and mask shared by every test in this class.

        Returns:
            A ``(init_image, mask_image)`` tuple, as produced by ``load_image``.
        """
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
        )
        return init_image, mask_image

    def _assert_center_slice(self, images, expected_slice):
        """Check the output batch shape and a 3x3 patch of the last channel near the center.

        Args:
            images: ``output.images`` from a pipeline call made with ``output_type="np"``.
            expected_slice: Nine reference values, compared with absolute tolerance ``1e-3``.
        """
        # Check the shape first. The slice below assumes one 512x512 image with 3 channels.
        assert images.shape == (1, 512, 512, 3)
        image_slice = images[0, 255:258, 255:258, -1]
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    def test_inference_default_pndm(self):
        """Inpaint with the checkpoint's default scheduler (PNDM, per the test name) for 10 steps."""
        init_image, mask_image = self._load_inpaint_inputs()
        pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            revision="onnx",
            safety_checker=None,
            feature_extractor=None,
            provider=self.gpu_provider,
            sess_options=self.gpu_options,
        )
        pipe.set_progress_bar_config(disable=None)

        prompt = "A red cat sitting on a park bench"

        # A fixed NumPy seed keeps the output reproducible against the reference slice.
        generator = np.random.RandomState(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            guidance_scale=7.5,
            num_inference_steps=10,
            generator=generator,
            output_type="np",
        )

        expected_slice = np.array([0.2514, 0.3007, 0.3517, 0.1790, 0.2382, 0.3167, 0.1944, 0.2273, 0.2464])
        self._assert_center_slice(output.images, expected_slice)

    def test_inference_k_lms(self):
        """Inpaint with an ``LMSDiscreteScheduler`` swapped in for 20 steps."""
        init_image, mask_image = self._load_inpaint_inputs()
        lms_scheduler = LMSDiscreteScheduler.from_pretrained(
            "runwayml/stable-diffusion-inpainting", subfolder="scheduler", revision="onnx"
        )
        pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            revision="onnx",
            scheduler=lms_scheduler,
            safety_checker=None,
            feature_extractor=None,
            provider=self.gpu_provider,
            sess_options=self.gpu_options,
        )
        pipe.set_progress_bar_config(disable=None)

        prompt = "A red cat sitting on a park bench"

        # A fixed NumPy seed keeps the output reproducible against the reference slice.
        generator = np.random.RandomState(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            guidance_scale=7.5,
            num_inference_steps=20,
            generator=generator,
            output_type="np",
        )

        expected_slice = np.array([0.0086, 0.0077, 0.0083, 0.0093, 0.0107, 0.0139, 0.0094, 0.0097, 0.0125])
        self._assert_center_slice(output.images, expected_slice)