Path: blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import shutil
from pathlib import Path

import onnx
import torch
from packaging import version
from torch.onnx import export

from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline


is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")


def onnx_export(
    model,
    model_args: tuple,
    output_path: Path,
    ordered_input_names,
    output_names,
    dynamic_axes,
    opset,
    use_external_data_format=False,
):
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
    # so we check the torch version for backwards compatibility
    if is_torch_less_than_1_11:
        export(
            model,
            model_args,
            f=output_path.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            enable_onnx_checker=True,
            opset_version=opset,
        )
    else:
        export(
            model,
            model_args,
            f=output_path.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset,
        )


@torch.no_grad()
def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
    dtype = torch.float16 if fp16 else torch.float32
    if fp16 and torch.cuda.is_available():
        device = "cuda"
    elif fp16 and not torch.cuda.is_available():
        raise ValueError("`float16` model export is only supported on GPUs with CUDA")
    else:
        device = "cpu"
    pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
    output_path = Path(output_path)

    # TEXT ENCODER
    num_tokens = pipeline.text_encoder.config.max_position_embeddings
    text_hidden_size = pipeline.text_encoder.config.hidden_size
    text_input = pipeline.tokenizer(
        "A sample prompt",
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    onnx_export(
        pipeline.text_encoder,
        # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
        model_args=(text_input.input_ids.to(device=device, dtype=torch.int32)),
        output_path=output_path / "text_encoder" / "model.onnx",
        ordered_input_names=["input_ids"],
        output_names=["last_hidden_state", "pooler_output"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
        },
        opset=opset,
    )
    del pipeline.text_encoder

    # UNET
    unet_in_channels = pipeline.unet.config.in_channels
    unet_sample_size = pipeline.unet.config.sample_size
    unet_path = output_path / "unet" / "model.onnx"
    onnx_export(
        pipeline.unet,
        model_args=(
            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
            torch.randn(2).to(device=device, dtype=dtype),
            torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
            False,
        ),
        output_path=unet_path,
        ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
        },
        opset=opset,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )
    unet_model_path = str(unet_path.absolute().as_posix())
    unet_dir = os.path.dirname(unet_model_path)
    unet = onnx.load(unet_model_path)
    # clean up existing tensor files
    shutil.rmtree(unet_dir)
    os.mkdir(unet_dir)
    # collate external tensor files into one
    onnx.save_model(
        unet,
        unet_model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location="weights.pb",
        convert_attribute=False,
    )
    del pipeline.unet

    # VAE ENCODER
    vae_encoder = pipeline.vae
    vae_in_channels = vae_encoder.config.in_channels
    vae_sample_size = vae_encoder.config.sample_size
    # need to get the raw tensor output (sample) from the encoder
    vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
    onnx_export(
        vae_encoder,
        model_args=(
            torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype),
            False,
        ),
        output_path=output_path / "vae_encoder" / "model.onnx",
        ordered_input_names=["sample", "return_dict"],
        output_names=["latent_sample"],
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )

    # VAE DECODER
    vae_decoder = pipeline.vae
    vae_latent_channels = vae_decoder.config.latent_channels
    vae_out_channels = vae_decoder.config.out_channels
    # forward only through the decoder part
    vae_decoder.forward = vae_encoder.decode
    onnx_export(
        vae_decoder,
        model_args=(
            torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
            False,
        ),
        output_path=output_path / "vae_decoder" / "model.onnx",
        ordered_input_names=["latent_sample", "return_dict"],
        output_names=["sample"],
        dynamic_axes={
            "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )
    del pipeline.vae

    # SAFETY CHECKER
    if pipeline.safety_checker is not None:
        safety_checker = pipeline.safety_checker
        clip_num_channels = safety_checker.config.vision_config.num_channels
        clip_image_size = safety_checker.config.vision_config.image_size
        safety_checker.forward = safety_checker.forward_onnx
        onnx_export(
            pipeline.safety_checker,
            model_args=(
                torch.randn(
                    1,
                    clip_num_channels,
                    clip_image_size,
                    clip_image_size,
                ).to(device=device, dtype=dtype),
                torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(device=device, dtype=dtype),
            ),
            output_path=output_path / "safety_checker" / "model.onnx",
            ordered_input_names=["clip_input", "images"],
            output_names=["out_images", "has_nsfw_concepts"],
            dynamic_axes={
                "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
                "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
            },
            opset=opset,
        )
        del pipeline.safety_checker
        safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
        feature_extractor = pipeline.feature_extractor
    else:
        safety_checker = None
        feature_extractor = None

    onnx_pipeline = OnnxStableDiffusionPipeline(
        vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
        vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
        text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
        tokenizer=pipeline.tokenizer,
        unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
        scheduler=pipeline.scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=safety_checker is not None,
    )

    onnx_pipeline.save_pretrained(output_path)
    print("ONNX pipeline saved to", output_path)

    del pipeline
    del onnx_pipeline
    _ = OnnxStableDiffusionPipeline.from_pretrained(output_path, provider="CPUExecutionProvider")
    print("ONNX pipeline is loadable")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
    )

    parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--opset",
        default=14,
        type=int,
        help="The version of the ONNX operator set to use.",
    )
    parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")

    args = parser.parse_args()

    convert_models(args.model_path, args.output_path, args.opset, args.fp16)
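
# Example usage (a minimal sketch; the paths below are placeholders, not part of this script).
# The first command runs the conversion via the CLI defined above; the second snippet mirrors
# the loadability check at the end of `convert_models` and runs the exported pipeline on CPU.
#
#   python convert_stable_diffusion_checkpoint_to_onnx.py \
#       --model_path ./path/to/diffusers/checkpoint \
#       --output_path ./stable_diffusion_onnx \
#       --opset 14
#
#   from diffusers import OnnxStableDiffusionPipeline
#
#   pipe = OnnxStableDiffusionPipeline.from_pretrained(
#       "./stable_diffusion_onnx", provider="CPUExecutionProvider"
#   )
#   image = pipe("A sample prompt").images[0]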