Path: blob/main/examples/community/checkpoint_merger.py
import glob
import os
from typing import Dict, List, Union

import torch

from diffusers.utils import is_safetensors_available


if is_safetensors_available():
    import safetensors.torch

from huggingface_hub import snapshot_download

from diffusers import DiffusionPipeline, __version__
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME


class CheckpointMergerPipeline(DiffusionPipeline):
    """
    A class that supports merging diffusion models, based on the discussion here:
    https://github.com/huggingface/diffusers/issues/877

    Example usage:

    pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger.py")

    merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4", "prompthero/openjourney"], interp="inv_sigmoid", alpha=0.8, force=True)

    merged_pipe.to("cuda")

    prompt = "An astronaut riding a unicycle on Mars"

    results = merged_pipe(prompt)

    For more details, see the docstring of the merge method.
    """

    def __init__(self):
        self.register_to_config()
        super().__init__()

    def _compare_model_configs(self, dict0, dict1):
        # Two configs are considered compatible if they match exactly, or if they
        # match once the meta keys (those starting with "_") are stripped.
        if dict0 == dict1:
            return True
        else:
            config0, meta_keys0 = self._remove_meta_keys(dict0)
            config1, meta_keys1 = self._remove_meta_keys(dict1)
            if config0 == config1:
                print(f"Warning: mismatch in meta keys {meta_keys0} and {meta_keys1}.")
                return True
        return False

    def _remove_meta_keys(self, config_dict: Dict):
        meta_keys = []
        temp_dict = config_dict.copy()
        for key in config_dict.keys():
            if key.startswith("_"):
                temp_dict.pop(key)
                meta_keys.append(key)
        return (temp_dict, meta_keys)

    @torch.no_grad()
    def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs):
        """
        Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints (weights) of the
        models passed in the argument 'pretrained_model_name_or_path_list' as a list.

        Parameters:
        -----------
        pretrained_model_name_or_path_list : A list of valid pretrained model names on the Hugging Face Hub, or paths
            to locally stored models in the Hugging Face format.

        **kwargs:
            Supports all the default DiffusionPipeline.get_config_dict kwargs, viz.

            cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision,
            torch_dtype, device_map.

            alpha - The interpolation parameter. Ranges from 0 to 1. It controls the ratio in which the checkpoints
                are merged: with alpha = 0.8 the first model's weights contribute far less to the final result than
                they would with alpha = 0.2.

            interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_diff"
                and None. Passing None uses the default, weighted-sum interpolation. For merging three checkpoints,
                only "add_diff" is supported.

            force - Whether to ignore mismatches in model_index.json for the given models.
                Defaults to False.
        """
        # Default kwargs from DiffusionPipeline
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        device_map = kwargs.pop("device_map", None)

        alpha = kwargs.pop("alpha", 0.5)
        interp = kwargs.pop("interp", None)

        print("Received list", pretrained_model_name_or_path_list)
        print(f"Combining with alpha={alpha}, interpolation mode={interp}")

        checkpoint_count = len(pretrained_model_name_or_path_list)
        # Whether to ignore the result of the model_index.json comparison between the checkpoints
        force = kwargs.pop("force", False)

        # Fewer than 2 checkpoints means there is nothing to merge; more than 3 is not supported for now.
        if checkpoint_count > 3 or checkpoint_count < 2:
            raise ValueError(
                "Received incorrect number of checkpoints to merge. Ensure that either 2 or 3 checkpoints are being"
                " passed."
            )

        print("Received the right number of checkpoints")

        # Validate that the checkpoints can be merged.
        # Step 1: Load the model configs and compare the checkpoints. We'll compare the model_index.json files first
        # while ignoring the keys starting with "_".
        config_dicts = []
        for pretrained_model_name_or_path in pretrained_model_name_or_path_list:
            config_dict = DiffusionPipeline.load_config(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
            )
            config_dicts.append(config_dict)

        comparison_result = True
        for idx in range(1, len(config_dicts)):
            comparison_result &= self._compare_model_configs(config_dicts[idx - 1], config_dicts[idx])
            if not force and comparison_result is False:
                raise ValueError("Incompatible checkpoints. Please check model_index.json for the models.")
        print(config_dicts[0], config_dicts[1])
        print("Compatible model_index.json files found")
        # Step 2: Basic validation has succeeded. Let's download the models and save them into the local cache.
        cached_folders = []
        for pretrained_model_name_or_path, config_dict in zip(pretrained_model_name_or_path_list, config_dicts):
            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
            allow_patterns = [os.path.join(k, "*") for k in folder_names]
            allow_patterns += [
                WEIGHTS_NAME,
                SCHEDULER_CONFIG_NAME,
                CONFIG_NAME,
                ONNX_WEIGHTS_NAME,
                DiffusionPipeline.config_name,
            ]
            requested_pipeline_class = config_dict.get("_class_name")
            user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}

            cached_folder = (
                pretrained_model_name_or_path
                if os.path.isdir(pretrained_model_name_or_path)
                else snapshot_download(
                    pretrained_model_name_or_path,
                    cache_dir=cache_dir,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    allow_patterns=allow_patterns,
                    user_agent=user_agent,
                )
            )
            print("Cached folder", cached_folder)
            cached_folders.append(cached_folder)

        # Step 3: Load the first checkpoint as a diffusion pipeline, then modify its module state_dicts in place.
        final_pipe = DiffusionPipeline.from_pretrained(
            cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
        )
        final_pipe.to(self.device)

        checkpoint_path_2 = None
        if len(cached_folders) > 2:
            checkpoint_path_2 = cached_folders[2]

        if interp == "sigmoid":
            theta_func = CheckpointMergerPipeline.sigmoid
        elif interp == "inv_sigmoid":
            theta_func = CheckpointMergerPipeline.inv_sigmoid
        elif interp == "add_diff":
            theta_func = CheckpointMergerPipeline.add_difference
        else:
            theta_func = CheckpointMergerPipeline.weighted_sum

        # Find each module's state dict and merge it with the corresponding weights from the other checkpoints.
        for attr in final_pipe.config.keys():
            if not attr.startswith("_"):
                checkpoint_path_1 = os.path.join(cached_folders[1], attr)
                if os.path.exists(checkpoint_path_1):
                    files = [
                        *glob.glob(os.path.join(checkpoint_path_1, "*.safetensors")),
                        *glob.glob(os.path.join(checkpoint_path_1, "*.bin")),
                    ]
                    checkpoint_path_1 = files[0] if len(files) > 0 else None
                if len(cached_folders) < 3:
                    checkpoint_path_2 = None
                else:
                    checkpoint_path_2 = os.path.join(cached_folders[2], attr)
                    if os.path.exists(checkpoint_path_2):
                        files = [
                            *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
                            *glob.glob(os.path.join(checkpoint_path_2, "*.bin")),
                        ]
                        checkpoint_path_2 = files[0] if len(files) > 0 else None
                # If both checkpoint_path_1 and checkpoint_path_2 are None for an attr, ignore it.
                # If at least one is present, merge according to the interp method, provided the state_dict keys match.
                if checkpoint_path_1 is None and checkpoint_path_2 is None:
                    print(f"Skipping {attr}: not present in the 2nd or 3rd model")
                    continue
                try:
                    module = getattr(final_pipe, attr)
                    if isinstance(module, bool):  # ignore the requires_safety_checker boolean
                        continue
                    theta_0 = module.state_dict()
                    update_theta_0 = module.load_state_dict

                    theta_1 = (
                        safetensors.torch.load_file(checkpoint_path_1)
                        if (is_safetensors_available() and checkpoint_path_1.endswith(".safetensors"))
                        else torch.load(checkpoint_path_1, map_location="cpu")
                    )
                    theta_2 = None
                    if checkpoint_path_2:
                        theta_2 = (
                            safetensors.torch.load_file(checkpoint_path_2)
                            if (is_safetensors_available() and checkpoint_path_2.endswith(".safetensors"))
                            else torch.load(checkpoint_path_2, map_location="cpu")
map_location="cpu")244)245246if not theta_0.keys() == theta_1.keys():247print(f"Skipping {attr}: key mismatch")248continue249if theta_2 and not theta_1.keys() == theta_2.keys():250print(f"Skipping {attr}:y mismatch")251except Exception as e:252print(f"Skipping {attr} do to an unexpected error: {str(e)}")253continue254print(f"MERGING {attr}")255256for key in theta_0.keys():257if theta_2:258theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], alpha)259else:260theta_0[key] = theta_func(theta_0[key], theta_1[key], None, alpha)261262del theta_1263del theta_2264update_theta_0(theta_0)265266del theta_0267return final_pipe268269@staticmethod270def weighted_sum(theta0, theta1, theta2, alpha):271return ((1 - alpha) * theta0) + (alpha * theta1)272273# Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)274@staticmethod275def sigmoid(theta0, theta1, theta2, alpha):276alpha = alpha * alpha * (3 - (2 * alpha))277return theta0 + ((theta1 - theta0) * alpha)278279# Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)280@staticmethod281def inv_sigmoid(theta0, theta1, theta2, alpha):282import math283284alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)285return theta0 + ((theta1 - theta0) * alpha)286287@staticmethod288def add_difference(theta0, theta1, theta2, alpha):289return theta0 + (theta1 - theta2) * (1.0 - alpha)290291292