GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_vae_diff_to_onnx.py
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
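
"""Export the VAE decoder of a `diffusers` checkpoint to ONNX.

Loads an `AutoencoderKL`, re-routes its `forward` to `decode` so that only the
decoder is traced, and writes the graph to `<output_path>/vae_decoder/model.onnx`.
"""
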
import argparse
from pathlib import Path

import torch
from packaging import version
from torch.onnx import export

from diffusers import AutoencoderKL

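# torch 1.11 removed the `use_external_data_format` and `enable_onnx_checker`
# arguments from `torch.onnx.export` (see `onnx_export` below); `base_version`
# drops local build tags such as "+cu117" so the version comparison is reliable.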
is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")


def onnx_export(
    model,
    model_args: tuple,
    output_path: Path,
    ordered_input_names,
    output_names,
    dynamic_axes,
    opset,
    use_external_data_format=False,
):
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
    # so we check the torch version for backwards compatibility
    if is_torch_less_than_1_11:
        export(
            model,
            model_args,
            f=output_path.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            enable_onnx_checker=True,
            opset_version=opset,
        )
    else:
        export(
            model,
            model_args,
            f=output_path.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset,
        )


@torch.no_grad()
def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
    dtype = torch.float16 if fp16 else torch.float32
    if fp16 and torch.cuda.is_available():
        device = "cuda"
    elif fp16 and not torch.cuda.is_available():
        raise ValueError("`float16` model export is only supported on GPUs with CUDA")
    else:
        device = "cpu"
    output_path = Path(output_path)

    # VAE DECODER
    vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae")
    # move the model onto the same device/dtype as the dummy input below;
    # otherwise the fp16/CUDA export path would trace with mismatched types
    vae_decoder = vae_decoder.to(device=device, dtype=dtype)
    vae_latent_channels = vae_decoder.config.latent_channels
    # forward only through the decoder part
    vae_decoder.forward = vae_decoder.decode
    onnx_export(
        vae_decoder,
        model_args=(
            # dummy latent; the spatial size is arbitrary since batch, height,
            # and width are declared as dynamic axes below
            torch.randn(1, vae_latent_channels, 25, 25).to(device=device, dtype=dtype),
            False,
        ),
        output_path=output_path / "vae_decoder" / "model.onnx",
        ordered_input_names=["latent_sample", "return_dict"],
        output_names=["sample"],
        dynamic_axes={
            "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )
    del vae_decoder


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
    )

    parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--opset",
        default=14,
        type=int,
        help="The version of the ONNX operator set to use.",
    )
    parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")

    args = parser.parse_args()
    print(args.output_path)
    convert_models(args.model_path, args.output_path, args.opset, args.fp16)
    print("SD: Done: ONNX")
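
# Example invocation and a quick smoke test (a sketch; the checkpoint directory
# and output path below are illustrative, not part of this repository):
#
#   python convert_vae_diff_to_onnx.py \
#       --model_path ./stable-diffusion-v1-5 \
#       --output_path ./vae_onnx \
#       --opset 14
#
# For Stable Diffusion v1 checkpoints the latent space has 4 channels at 1/8th
# of the pixel resolution, so a 64x64 latent decodes to a 512x512 image:
#
#   import numpy as np
#   import onnxruntime as ort
#
#   sess = ort.InferenceSession("vae_onnx/vae_decoder/model.onnx")
#   latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
#   (image,) = sess.run(None, {"latent_sample": latents})  # image: (1, 3, 512, 512)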