GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/checkpoint_merger.py

import glob
import os
from typing import Dict, List, Union

import torch

from diffusers.utils import is_safetensors_available


if is_safetensors_available():
    import safetensors.torch

from huggingface_hub import snapshot_download

from diffusers import DiffusionPipeline, __version__
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME

class CheckpointMergerPipeline(DiffusionPipeline):
    """
    A class that supports merging diffusion models based on the discussion here:
    https://github.com/huggingface/diffusers/issues/877

    Example usage:

    pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger.py")

    merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4", "prompthero/openjourney"], interp='inv_sigmoid', alpha=0.8, force=True)

    merged_pipe.to('cuda')

    prompt = "An astronaut riding a unicycle on Mars"

    results = merged_pipe(prompt)

    ## For more details, see the docstring for the merge method.

    """

    def __init__(self):
        self.register_to_config()
        super().__init__()

    def _compare_model_configs(self, dict0, dict1):
        if dict0 == dict1:
            return True
        else:
            config0, meta_keys0 = self._remove_meta_keys(dict0)
            config1, meta_keys1 = self._remove_meta_keys(dict1)
            if config0 == config1:
                print(f"Warning: mismatch in meta keys {meta_keys0} and {meta_keys1}.")
                return True
        return False

    def _remove_meta_keys(self, config_dict: Dict):
        meta_keys = []
        temp_dict = config_dict.copy()
        for key in config_dict.keys():
            if key.startswith("_"):
                temp_dict.pop(key)
                meta_keys.append(key)
        return (temp_dict, meta_keys)

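    # A minimal sketch of the two helpers above on a hypothetical config dict:
    #
    #     config = {"_class_name": "StableDiffusionPipeline", "unet": ["diffusers", "UNet2DConditionModel"]}
    #     trimmed, meta_keys = self._remove_meta_keys(config)
    #     # trimmed == {"unet": ["diffusers", "UNet2DConditionModel"]}
    #     # meta_keys == ["_class_name"]
    #
    # _compare_model_configs therefore treats two model_index.json dicts as compatible
    # when they differ only in keys starting with "_" (it warns instead of failing).
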
    @torch.no_grad()
    def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs):
        """
        Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints (weights) of the models passed
        in the argument 'pretrained_model_name_or_path_list' as a list.

        Parameters:
        -----------
            pretrained_model_name_or_path_list : A list of valid pretrained model names in the HuggingFace hub or paths to locally stored models in the HuggingFace format.

            **kwargs:
                Supports all the default DiffusionPipeline.load_config kwargs, namely:

                cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map.

                alpha - The interpolation parameter. Ranges from 0 to 1. It controls the ratio in which the checkpoints are merged: with
                    an alpha of 0.8, the first model's weights affect the final result far less than they would with an alpha of 0.2.

                interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_diff" and None.
                    Passing None uses the default interpolation, which is weighted-sum interpolation. For merging three checkpoints, only "add_diff" is supported.

                force - Whether to ignore a mismatch in model_index.json for the current models. Defaults to False.

        """
        # Default kwargs from DiffusionPipeline
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        device_map = kwargs.pop("device_map", None)

        alpha = kwargs.pop("alpha", 0.5)
        interp = kwargs.pop("interp", None)

print("Received list", pretrained_model_name_or_path_list)
104
print(f"Combining with alpha={alpha}, interpolation mode={interp}")
105
106
checkpoint_count = len(pretrained_model_name_or_path_list)
107
# Ignore result from model_index_json comparision of the two checkpoints
108
force = kwargs.pop("force", False)
109
110
# If less than 2 checkpoints, nothing to merge. If more than 3, not supported for now.
111
if checkpoint_count > 3 or checkpoint_count < 2:
112
raise ValueError(
113
"Received incorrect number of checkpoints to merge. Ensure that either 2 or 3 checkpoints are being"
114
" passed."
115
)
116
117
print("Received the right number of checkpoints")
118
# chkpt0, chkpt1 = pretrained_model_name_or_path_list[0:2]
119
# chkpt2 = pretrained_model_name_or_path_list[2] if checkpoint_count == 3 else None
120
121
        # Validate that the checkpoints can be merged
        # Step 1: Load the model configs and compare the checkpoints. We'll compare the model_index.json first while ignoring the keys starting with '_'
        config_dicts = []
        for pretrained_model_name_or_path in pretrained_model_name_or_path_list:
            config_dict = DiffusionPipeline.load_config(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
            )
            config_dicts.append(config_dict)

        comparison_result = True
        for idx in range(1, len(config_dicts)):
            comparison_result &= self._compare_model_configs(config_dicts[idx - 1], config_dicts[idx])
            if not force and comparison_result is False:
                raise ValueError("Incompatible checkpoints. Please check model_index.json for the models.")
        print(config_dicts[0], config_dicts[1])
        print("Compatible model_index.json files found")
        # Step 2: Basic validation has succeeded. Let's download the models and save them into our local files.
        cached_folders = []
        for pretrained_model_name_or_path, config_dict in zip(pretrained_model_name_or_path_list, config_dicts):
            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
            allow_patterns = [os.path.join(k, "*") for k in folder_names]
            allow_patterns += [
                WEIGHTS_NAME,
                SCHEDULER_CONFIG_NAME,
                CONFIG_NAME,
                ONNX_WEIGHTS_NAME,
                DiffusionPipeline.config_name,
            ]
            requested_pipeline_class = config_dict.get("_class_name")
            user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}

            cached_folder = (
                pretrained_model_name_or_path
                if os.path.isdir(pretrained_model_name_or_path)
                else snapshot_download(
                    pretrained_model_name_or_path,
                    cache_dir=cache_dir,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    allow_patterns=allow_patterns,
                    user_agent=user_agent,
                )
            )
            print("Cached Folder", cached_folder)
            cached_folders.append(cached_folder)

        # Step 3:
        # Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place
        final_pipe = DiffusionPipeline.from_pretrained(
            cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
        )
        final_pipe.to(self.device)

        checkpoint_path_2 = None
        if len(cached_folders) > 2:
            checkpoint_path_2 = os.path.join(cached_folders[2])

        if interp == "sigmoid":
            theta_func = CheckpointMergerPipeline.sigmoid
        elif interp == "inv_sigmoid":
            theta_func = CheckpointMergerPipeline.inv_sigmoid
        elif interp == "add_diff":
            theta_func = CheckpointMergerPipeline.add_difference
        else:
            theta_func = CheckpointMergerPipeline.weighted_sum

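        # For reference, the per-tensor updates the selected theta_func performs
        # (see the static methods at the bottom of this class):
        #   weighted_sum:   (1 - alpha) * theta0 + alpha * theta1
        #   sigmoid:        the same lerp, with alpha first remapped by smoothstep
        #   inv_sigmoid:    the same lerp, with alpha first remapped by inverse smoothstep
        #   add_difference: theta0 + (theta1 - theta2) * (1 - alpha)
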
        # Find each module's state dict.
        for attr in final_pipe.config.keys():
            if not attr.startswith("_"):
                checkpoint_path_1 = os.path.join(cached_folders[1], attr)
                if os.path.exists(checkpoint_path_1):
                    files = list(
                        (
                            *glob.glob(os.path.join(checkpoint_path_1, "*.safetensors")),
                            *glob.glob(os.path.join(checkpoint_path_1, "*.bin")),
                        )
                    )
                    checkpoint_path_1 = files[0] if len(files) > 0 else None
                if len(cached_folders) < 3:
                    checkpoint_path_2 = None
                else:
                    checkpoint_path_2 = os.path.join(cached_folders[2], attr)
                    if os.path.exists(checkpoint_path_2):
                        files = list(
                            (
                                *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
                                *glob.glob(os.path.join(checkpoint_path_2, "*.bin")),
                            )
                        )
                        checkpoint_path_2 = files[0] if len(files) > 0 else None
                # For an attr, if both checkpoint_path_1 and checkpoint_path_2 are None, ignore it.
                # If at least one is present, handle it according to the interp method, of course only if the state_dict keys match.
                if checkpoint_path_1 is None and checkpoint_path_2 is None:
                    print(f"Skipping {attr}: not present in 2nd or 3rd model")
                    continue
                try:
                    module = getattr(final_pipe, attr)
                    if isinstance(module, bool):  # ignore requires_safety_checker boolean
                        continue
                    theta_0 = getattr(module, "state_dict")
                    theta_0 = theta_0()

                    update_theta_0 = getattr(module, "load_state_dict")
                    theta_1 = (
                        safetensors.torch.load_file(checkpoint_path_1)
                        if (is_safetensors_available() and checkpoint_path_1.endswith(".safetensors"))
                        else torch.load(checkpoint_path_1, map_location="cpu")
                    )
                    theta_2 = None
                    if checkpoint_path_2:
                        theta_2 = (
                            safetensors.torch.load_file(checkpoint_path_2)
                            if (is_safetensors_available() and checkpoint_path_2.endswith(".safetensors"))
                            else torch.load(checkpoint_path_2, map_location="cpu")
                        )

                    if not theta_0.keys() == theta_1.keys():
                        print(f"Skipping {attr}: key mismatch")
                        continue
                    if theta_2 and not theta_1.keys() == theta_2.keys():
                        print(f"Skipping {attr}: key mismatch")
                        continue
                except Exception as e:
                    print(f"Skipping {attr} due to an unexpected error: {str(e)}")
                    continue
                print(f"MERGING {attr}")

                for key in theta_0.keys():
                    if theta_2:
                        theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], alpha)
                    else:
                        theta_0[key] = theta_func(theta_0[key], theta_1[key], None, alpha)

                del theta_1
                del theta_2
                update_theta_0(theta_0)

                del theta_0
        return final_pipe

    @staticmethod
    def weighted_sum(theta0, theta1, theta2, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
    @staticmethod
    def sigmoid(theta0, theta1, theta2, alpha):
        alpha = alpha * alpha * (3 - (2 * alpha))
        return theta0 + ((theta1 - theta0) * alpha)

    # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
    @staticmethod
    def inv_sigmoid(theta0, theta1, theta2, alpha):
        import math

        alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
        return theta0 + ((theta1 - theta0) * alpha)

    @staticmethod
    def add_difference(theta0, theta1, theta2, alpha):
        return theta0 + (theta1 - theta2) * (1.0 - alpha)
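
# A minimal, self-contained sanity check of the interpolation functions above.
# The tensors and alpha are arbitrary example values, not part of the merge pipeline.
if __name__ == "__main__":
    t0 = torch.zeros(3)
    t1 = torch.ones(3)
    alpha = 0.8
    # weighted_sum: (1 - 0.8) * 0 + 0.8 * 1 = 0.8
    print(CheckpointMergerPipeline.weighted_sum(t0, t1, None, alpha))
    # sigmoid (smoothstep): alpha -> 0.8 * 0.8 * (3 - 1.6) = 0.896
    print(CheckpointMergerPipeline.sigmoid(t0, t1, None, alpha))
    # inv_sigmoid (inverse smoothstep): alpha -> 0.5 - sin(asin(1 - 1.6) / 3) ~= 0.713
    print(CheckpointMergerPipeline.inv_sigmoid(t0, t1, None, alpha))
    # add_difference: t0 + (t1 - t2) * (1 - 0.8) = 0.2 when t2 is all zeros
    print(CheckpointMergerPipeline.add_difference(t0, t1, torch.zeros(3), alpha))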