# Path: blob/main/scripts/convert_original_controlnet_to_diffusers.py
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conversion script for stable diffusion checkpoints which _only_ contain a controlnet."""

import argparse

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--original_config_file",
        type=str,
        required=True,
        help="The YAML config file corresponding to the original architecture.",
    )
    parser.add_argument(
        "--num_in_channels",
        default=None,
        type=int,
        help="The number of input channels. If `None` number of input channels will be automatically inferred.",
    )
    parser.add_argument(
        "--image_size",
        default=512,
        type=int,
        help=(
            "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
            " Base. Use 768 for Stable Diffusion v2."
        ),
    )
    parser.add_argument(
        "--extract_ema",
        action="store_true",
        help=(
            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
        ),
    )
    parser.add_argument(
        "--upcast_attention",
        action="store_true",
        help=(
            "Whether the attention computation should always be upcasted. This is necessary when running stable"
            " diffusion 2.1."
        ),
    )
    parser.add_argument(
        "--from_safetensors",
        action="store_true",
        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
    )
    parser.add_argument(
        "--to_safetensors",
        action="store_true",
        help="Whether to store pipeline in safetensors format or not.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
    args = parser.parse_args()

    # Load the original-format controlnet checkpoint and convert it into a
    # diffusers `ControlNetModel` using the shared conversion helper.
    controlnet = download_controlnet_from_original_ckpt(
        checkpoint_path=args.checkpoint_path,
        original_config_file=args.original_config_file,
        image_size=args.image_size,
        extract_ema=args.extract_ema,
        num_in_channels=args.num_in_channels,
        upcast_attention=args.upcast_attention,
        from_safetensors=args.from_safetensors,
        device=args.device,
    )

    # Serialize the converted model in diffusers layout; `safe_serialization`
    # toggles between safetensors and PyTorch `.bin` output.
    controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)