GitHub Repository: shivamshrirao/diffusers
Path: blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
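"""Convert a Stable Diffusion checkpoint in `diffusers` format to ONNX.

Exports the text encoder, UNet, VAE encoder, VAE decoder and (when present)
the safety checker as separate ONNX models, then reassembles them into an
`OnnxStableDiffusionPipeline` saved under `--output_path`.

Example (the model id is illustrative; any local directory or Hub repository
in `diffusers` format works):

    python convert_stable_diffusion_checkpoint_to_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ./stable_diffusion_onnx --opset 14
"""
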
import argparse
import os
import shutil
from pathlib import Path

import onnx
import torch
from packaging import version
from torch.onnx import export

from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline


is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")


def onnx_export(
    model,
    model_args: tuple,
    output_path: Path,
    ordered_input_names,
    output_names,
    dynamic_axes,
    opset,
    use_external_data_format=False,
):
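    """Export `model` to ONNX at `output_path` via `torch.onnx.export`,
    creating the parent directory and accounting for the export arguments
    that PyTorch deprecated in v1.11 (see the version check below).
    """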
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
    # so we check the torch version for backwards compatibility
    if is_torch_less_than_1_11:
        export(
            model,
            model_args,
            f=output_path.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            enable_onnx_checker=True,
            opset_version=opset,
        )
    else:
        export(
            model,
            model_args,
            f=output_path.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset,
        )


@torch.no_grad()
def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
    dtype = torch.float16 if fp16 else torch.float32
    if fp16 and torch.cuda.is_available():
        device = "cuda"
    elif fp16 and not torch.cuda.is_available():
        raise ValueError("`float16` model export is only supported on GPUs with CUDA")
    else:
        device = "cpu"
    pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
    output_path = Path(output_path)

    # TEXT ENCODER
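    # Export the CLIP text encoder traced on a tokenized sample prompt.
    # `num_tokens` and `text_hidden_size` are read here so the UNet's dummy
    # `encoder_hidden_states` input below can be shaped to match.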
    num_tokens = pipeline.text_encoder.config.max_position_embeddings
    text_hidden_size = pipeline.text_encoder.config.hidden_size
    text_input = pipeline.tokenizer(
        "A sample prompt",
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    onnx_export(
        pipeline.text_encoder,
        # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
        model_args=(text_input.input_ids.to(device=device, dtype=torch.int32)),
        output_path=output_path / "text_encoder" / "model.onnx",
        ordered_input_names=["input_ids"],
        output_names=["last_hidden_state", "pooler_output"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
        },
        opset=opset,
    )
    del pipeline.text_encoder

    # UNET
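    # Trace the UNet on dummy `sample`, `timestep` and `encoder_hidden_states`
    # tensors; the dynamic axes declared below let the exported graph accept
    # other batch sizes and resolutions at runtime.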
    unet_in_channels = pipeline.unet.config.in_channels
    unet_sample_size = pipeline.unet.config.sample_size
    unet_path = output_path / "unet" / "model.onnx"
    onnx_export(
        pipeline.unet,
        model_args=(
            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
            torch.randn(2).to(device=device, dtype=dtype),
            torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
            False,
        ),
        output_path=unet_path,
        ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
        },
        opset=opset,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )
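    # ONNX's protobuf container caps a single file at 2GB, so the UNet was
    # exported with external data; reload it and re-save with all tensors
    # collated into one `weights.pb` file alongside `model.onnx`.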
    unet_model_path = str(unet_path.absolute().as_posix())
    unet_dir = os.path.dirname(unet_model_path)
    unet = onnx.load(unet_model_path)
    # clean up existing tensor files
    shutil.rmtree(unet_dir)
    os.mkdir(unet_dir)
    # collate external tensor files into one
    onnx.save_model(
        unet,
        unet_model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location="weights.pb",
        convert_attribute=False,
    )
    del pipeline.unet

    # VAE ENCODER
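    # The pipeline holds a single VAE module, which is exported twice: with
    # `forward` patched to run only the encoder here, and only the decoder
    # below.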
    vae_encoder = pipeline.vae
    vae_in_channels = vae_encoder.config.in_channels
    vae_sample_size = vae_encoder.config.sample_size
    # need to get the raw tensor output (sample) from the encoder
    vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
    onnx_export(
        vae_encoder,
        model_args=(
            torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype),
            False,
        ),
        output_path=output_path / "vae_encoder" / "model.onnx",
        ordered_input_names=["sample", "return_dict"],
        output_names=["latent_sample"],
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )

    # VAE DECODER
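    # `vae_decoder` is the same object as `vae_encoder` (both are
    # `pipeline.vae`), so binding `forward` to `vae_encoder.decode` traces
    # the decoder half only.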
    vae_decoder = pipeline.vae
    vae_latent_channels = vae_decoder.config.latent_channels
    vae_out_channels = vae_decoder.config.out_channels
    # forward only through the decoder part
    vae_decoder.forward = vae_encoder.decode
    onnx_export(
        vae_decoder,
        model_args=(
            torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
            False,
        ),
        output_path=output_path / "vae_decoder" / "model.onnx",
        ordered_input_names=["latent_sample", "return_dict"],
        output_names=["sample"],
        dynamic_axes={
            "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )
    del pipeline.vae

    # SAFETY CHECKER
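    # The safety checker is exported through its `forward_onnx` method, which
    # returns plain tensors and is therefore traceable, unlike the default
    # `forward` with its per-image Python post-processing.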
    if pipeline.safety_checker is not None:
        safety_checker = pipeline.safety_checker
        clip_num_channels = safety_checker.config.vision_config.num_channels
        clip_image_size = safety_checker.config.vision_config.image_size
        safety_checker.forward = safety_checker.forward_onnx
        onnx_export(
            pipeline.safety_checker,
            model_args=(
                torch.randn(
                    1,
                    clip_num_channels,
                    clip_image_size,
                    clip_image_size,
                ).to(device=device, dtype=dtype),
                torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(device=device, dtype=dtype),
            ),
            output_path=output_path / "safety_checker" / "model.onnx",
            ordered_input_names=["clip_input", "images"],
            output_names=["out_images", "has_nsfw_concepts"],
            dynamic_axes={
                "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
                "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
            },
            opset=opset,
        )
        del pipeline.safety_checker
        safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
        feature_extractor = pipeline.feature_extractor
    else:
        safety_checker = None
        feature_extractor = None

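    # Reassemble the exported components into an OnnxStableDiffusionPipeline
    # and write it (a model_index.json plus one subfolder per component) to
    # `output_path`.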
    onnx_pipeline = OnnxStableDiffusionPipeline(
        vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
        vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
        text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
        tokenizer=pipeline.tokenizer,
        unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
        scheduler=pipeline.scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=safety_checker is not None,
    )

    onnx_pipeline.save_pretrained(output_path)
    print("ONNX pipeline saved to", output_path)

    del pipeline
    del onnx_pipeline
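    # Smoke test: reload the saved pipeline on the CPU execution provider to
    # verify the export is self-contained and loadable.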
    _ = OnnxStableDiffusionPipeline.from_pretrained(output_path, provider="CPUExecutionProvider")
    print("ONNX pipeline is loadable")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
    )

    parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--opset",
        default=14,
        type=int,
        help="The version of the ONNX operator set to use.",
    )
    parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")

    args = parser.parse_args()

    convert_models(args.model_path, args.output_path, args.opset, args.fp16)