GitHub Repository: TheLastBen/fast-stable-diffusion
Path: blob/main/Dreambooth/det.py
# Adapted from A1111
import argparse
import torch
import open_clip
import transformers.utils.hub
from safetensors import safe_open
import os
import sys
import wget
from subprocess import call
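
# Command-line interface: --MODEL_PATH is the checkpoint to inspect;
# --from_safetensors marks it as a .safetensors file rather than a pickled .ckpt.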
parser = argparse.ArgumentParser()
parser.add_argument("--MODEL_PATH", type=str)
parser.add_argument("--from_safetensors", action='store_true')
args = parser.parse_args()
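
# Example invocation (the model path is illustrative):
#   python det.py --MODEL_PATH /path/to/model.safetensors --from_safetensors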

# Fetch the repo's bundled ldm package and unpack it quietly beside this script.
wget.download("https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dreambooth/ldm.zip")
call('unzip ldm', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
call('rm ldm.zip', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))

# These imports resolve against the ldm package downloaded above.
import ldm.modules.diffusionmodules.openaimodel
import ldm.modules.encoders.modules
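
# DisableInitialization (adapted from A1111) is a context manager that temporarily
# monkey-patches torch, open_clip and transformers so that building a model skips
# weight initialization and pretrained-weight downloads. Every patch is reverted
# when the context exits.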
class DisableInitialization:

    def __init__(self, disable_clip=True):
        self.replaced = []
        self.disable_clip = disable_clip

    def replace(self, obj, field, func):
        # Swap obj.field for func, remembering the original so __exit__ can restore it.
        original = getattr(obj, field, None)
        if original is None:
            return None

        self.replaced.append((obj, field, original))
        setattr(obj, field, func)

        return original

    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
            # Build the open_clip model without fetching pretrained weights.
            return self.create_model_and_transforms(*args, pretrained=None, **kwargs)

        def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
            # Instantiate CLIPTextModel from its config alone, with an empty state dict.
            res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
            res.name_or_path = pretrained_model_name_or_path
            return res

        def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
            # Replace the resolved-archive-file argument with '/' so no weights are read from disk.
            args = args[0:3] + ('/', ) + args[4:]
            return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)

        def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):

            # added_tokens.json is absent from openai/clip-vit-large-patch14;
            # short-circuit instead of making a request that is known to fail.
            if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or (url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json'):
                return None

            # Try the local cache first; fall back to downloading only on a miss.
            try:
                res = original(url, *args, local_files_only=True, **kwargs)
                if res is None:
                    res = original(url, *args, local_files_only=False, **kwargs)
                return res
            except Exception:
                return original(url, *args, local_files_only=False, **kwargs)

        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)

        def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)

        def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)

        # Weight initializers become no-ops; the real weights come from the checkpoint anyway.
        self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)

        if self.disable_clip:
            self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
            self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
            self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
            self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
            self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
            self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore every patched attribute to its original value.
        for obj, field, original in self.replaced:
            setattr(obj, field, original)

        self.replaced.clear()
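

# vpar probes whether an SD2.x checkpoint uses v-prediction: the UNet runs once on a
# constant latent, and if the mean of (output - input) falls below -1 the checkpoint
# is treated as a v-prediction (768px) model rather than an epsilon (512px) one.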
def vpar(state_dict):

    device = torch.device("cuda")

    # Build the SD2 UNet skeleton quickly: no weight init, no downloads.
    with DisableInitialization():
        unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
            use_checkpoint=True,
            use_fp16=False,
            image_size=32,
            in_channels=4,
            out_channels=4,
            model_channels=320,
            attention_resolutions=[4, 2, 1],
            num_res_blocks=2,
            channel_mult=[1, 2, 4, 4],
            num_head_channels=64,
            use_spatial_transformer=True,
            use_linear_in_transformer=True,
            transformer_depth=1,
            context_dim=1024,
            legacy=False
        )
    unet.eval()

    with torch.no_grad():
        # Keep only the diffusion-model weights from the checkpoint.
        unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
        unet.load_state_dict(unet_sd, strict=True)
        unet.to(device=device, dtype=torch.float)

        # Probe with constant conditioning and a constant latent at the final timestep.
        test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
        x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5

        out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()

    return out < -1
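

# detect_version keys off the text-encoder projection width: SD2.x checkpoints carry a
# 1024-wide in_proj_weight on the first text-encoder attention block, and vpar() then
# separates 768px from 512px variants. Anything else is reported as a 1.5-style model.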
def detect_version(sd):

    # Silence stray prints from model construction inside vpar().
    sys.stdout = open(os.devnull, 'w')

    sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
    diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)  # fetched but unused here

    if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:

        if vpar(sd):
            sys.stdout = sys.__stdout__
            print("V2.1-768px")
            return "V2.1-768px"
        else:
            sys.stdout = sys.__stdout__
            print("V2.1-512px")
            return "V2.1-512px"

    else:
        sys.stdout = sys.__stdout__
        print("1.5")
        return "1.5"
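

# Load the checkpoint: .safetensors files are read tensor by tensor with safe_open;
# pickled .ckpt files go through torch.load, unwrapping a nested "state_dict" if present.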
if args.from_safetensors:

    checkpoint = {}
    with safe_open(args.MODEL_PATH, framework="pt", device="cuda") as f:
        for key in f.keys():
            checkpoint[key] = f.get_tensor(key)
    state_dict = checkpoint
else:
    checkpoint = torch.load(args.MODEL_PATH, map_location="cuda")
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

detect_version(state_dict)

# Remove the temporary ldm package downloaded at startup.
call('rm -r ldm', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))