Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
POSTECH-CVLab
GitHub Repository: POSTECH-CVLab/PyTorch-StudioGAN
Path: blob/master/src/utils/ckpt.py
809 views
1
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
2
# The MIT License (MIT)
3
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
4
5
# src/utils/ckpt.py
6
7
from os.path import join
8
import os
9
import glob
10
11
import torch
12
import numpy as np
13
14
import utils.log as log
15
try:
16
import utils.misc as misc
17
except AttributeError:
18
pass
19
20
blacklist = ["CCMGAN2048-train-2021_06_22_06_11_37"]
21
22
23
def make_ckpt_dir(ckpt_dir):
    """Create *ckpt_dir* if it does not already exist and return its path.

    Uses ``exist_ok=True`` instead of an exists-then-create pair so that
    concurrent callers (e.g. multiple DDP ranks) cannot race between the
    existence check and the directory creation.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    return ckpt_dir
def load_ckpt(model, optimizer, ckpt_path, load_model=False, load_opt=False, load_misc=False, is_freezeD=False):
    """Load weights, optimizer state, and/or training metadata from a checkpoint.

    Args:
        model: network whose parameters are restored when ``load_model`` is True.
        optimizer: optimizer whose state is restored when ``load_opt`` is True.
        ckpt_path: path to a ``torch.save``-ed checkpoint dictionary.
        load_model: restore ``ckpt["state_dict"]`` into *model*.
        load_opt: restore ``ckpt["optimizer"]`` and move its tensor state to GPU.
        load_misc: additionally return the stored training metadata.
        is_freezeD: use non-strict parameter copying (freezeD transfer learning).

    Returns:
        ``(seed, run_name, step, epoch, topk, aa_p, best_step, best_fid,
        best_ckpt_path, lecam_emas)`` when ``load_misc`` is True, else ``None``.
    """
    # Always deserialize onto CPU; callers decide where tensors end up.
    ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    if load_model:
        if is_freezeD:
            # freezeD transfer: copy whatever matches, report the rest.
            mismatch_names = misc.load_parameters(src=ckpt["state_dict"],
                                                  dst=model.state_dict(),
                                                  strict=False)
            print("The following parameters/buffers do not match with the ones of the pre-trained model:", mismatch_names)
        else:
            model.load_state_dict(ckpt["state_dict"], strict=True)

    if load_opt:
        optimizer.load_state_dict(ckpt["optimizer"])
        # Optimizer state was saved on CPU; move tensor entries back to GPU.
        # NOTE(review): assumes CUDA is available whenever load_opt is used.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

    if load_misc:
        seed = ckpt["seed"]
        run_name = ckpt["run_name"]
        step = ckpt["step"]
        # Older checkpoints stored the adaptive-augmentation probability
        # under "ada_p".  Catch only KeyError -- a bare except here would
        # also swallow unrelated failures.
        try:
            aa_p = ckpt["aa_p"]
        except KeyError:
            aa_p = ckpt["ada_p"]
        best_step = ckpt["best_step"]
        best_fid = ckpt["best_fid"]

        # The keys below were added over time; fall back to defaults so
        # old checkpoints remain loadable.
        try:
            epoch = ckpt["epoch"]
        except KeyError:
            epoch = 0
        try:
            topk = ckpt["topk"]
        except KeyError:
            topk = "initialize"
        try:
            best_ckpt_path = ckpt["best_fid_checkpoint_path"]
        except KeyError:
            best_ckpt_path = ckpt["best_fid_ckpt"]
        try:
            lecam_emas = ckpt["lecam_emas"]
        except KeyError:
            lecam_emas = None
        return seed, run_name, step, epoch, topk, aa_p, best_step, best_fid, best_ckpt_path, lecam_emas
def load_StudioGAN_ckpts(ckpt_dir, load_best, Gen, Dis, g_optimizer, d_optimizer, run_name, apply_g_ema, Gen_ema, ema,
                         is_train, RUN, logger, global_rank, device, cfg_file):
    """Resume generator/discriminator (and optionally the EMA generator) from *ckpt_dir*.

    Restores model weights, optimizer states (unless the stored run is
    blacklisted, freezeD fine-tuning is active, or the run is eval-only),
    and training metadata, then re-wires the EMA helper and the logger.

    Args:
        ckpt_dir: directory containing "model=G-..."/"model=D-..." .pth files.
        load_best: pick the "best" checkpoints instead of the "current" ones.
        Gen, Dis: generator / discriminator modules to load weights into.
        g_optimizer, d_optimizer: their optimizers.
        run_name: name of the new run; replaces the stored name under freezeD.
        apply_g_ema, Gen_ema, ema: EMA generator and its tracking helper.
        is_train: whether training will continue after resuming.
        RUN: run configuration (reads ``freezeD``/``save_dir``, may update ``seed``).
        logger: current logger; may be replaced with one named after the
            resumed run.
        global_rank: distributed rank, added to the restored seed.
        device: process index; logging happens only when it is 0.
        cfg_file: unused here -- kept for call-site compatibility.

    Returns:
        ``(run_name, step, epoch, topk, aa_p, best_step, best_fid,
        best_ckpt_path, lecam_emas, logger)``.
    """
    when = "best" if load_best is True else "current"
    # Checkpoint filenames end with the step number, hence the trailing glob.
    # glob.escape guards against metacharacters in ckpt_dir; an IndexError
    # here means no matching checkpoint file was found.
    x = join(ckpt_dir, "model=G-{when}-weights-step=".format(when=when))
    Gen_ckpt_path = glob.glob(glob.escape(x) + '*.pth')[0]
    y = join(ckpt_dir, "model=D-{when}-weights-step=".format(when=when))
    Dis_ckpt_path = glob.glob(glob.escape(y) + '*.pth')[0]

    # Peek at the discriminator checkpoint to learn which run produced it.
    prev_run_name = torch.load(Dis_ckpt_path, map_location=lambda storage, loc: storage)["run_name"]
    is_freezeD = True if RUN.freezeD > -1 else False

    # Optimizer state is skipped for blacklisted runs, freezeD transfer,
    # and evaluation-only runs.
    load_ckpt(model=Gen,
              optimizer=g_optimizer,
              ckpt_path=Gen_ckpt_path,
              load_model=True,
              load_opt=False if prev_run_name in blacklist or is_freezeD or not is_train else True,
              load_misc=False,
              is_freezeD=is_freezeD)

    # The discriminator checkpoint carries the training metadata.
    seed, prev_run_name, step, epoch, topk, aa_p, best_step, best_fid, best_ckpt_path, lecam_emas =\
        load_ckpt(model=Dis,
                  optimizer=d_optimizer,
                  ckpt_path=Dis_ckpt_path,
                  load_model=True,
                  load_opt=False if prev_run_name in blacklist or is_freezeD or not is_train else True,
                  load_misc=True,
                  is_freezeD=is_freezeD)

    if apply_g_ema:
        z = join(ckpt_dir, "model=G_ema-{when}-weights-step=".format(when=when))
        Gen_ema_ckpt_path = glob.glob(glob.escape(z) + '*.pth')[0]
        load_ckpt(model=Gen_ema,
                  optimizer=None,
                  ckpt_path=Gen_ema_ckpt_path,
                  load_model=True,
                  load_opt=False,
                  load_misc=False,
                  is_freezeD=is_freezeD)

        # Re-point the EMA tracker at the freshly loaded modules.
        ema.source, ema.target = Gen, Gen_ema

    # Re-derive the per-process seed from the restored one so each rank
    # stays distinct yet reproducible.
    if is_train and RUN.seed != seed:
        RUN.seed = seed + global_rank
        misc.fix_seed(RUN.seed)

    if device == 0:  # only the primary process handles logging
        if not is_freezeD:
            # Continue logging under the resumed run's original name.
            logger = log.make_logger(RUN.save_dir, prev_run_name, None)

        logger.info("Generator checkpoint is {}".format(Gen_ckpt_path))
        if apply_g_ema:
            logger.info("EMA_Generator checkpoint is {}".format(Gen_ema_ckpt_path))
        logger.info("Discriminator checkpoint is {}".format(Dis_ckpt_path))

    if is_freezeD:
        # freezeD starts a fresh fine-tuning run: keep only the weights and
        # reset all training metadata.
        prev_run_name, step, epoch, topk, aa_p, best_step, best_fid, best_ckpt_path =\
            run_name, 0, 0, "initialize", None, 0, None, None
    return prev_run_name, step, epoch, topk, aa_p, best_step, best_fid, best_ckpt_path, lecam_emas, logger
def load_best_model(ckpt_dir, Gen, Dis, apply_g_ema, Gen_ema, ema):
    """Load the best-FID generator/discriminator weights for evaluation.

    Unwraps parallel containers via ``misc.peel_models``, restores the
    "best" checkpoints from *ckpt_dir*, re-wires the EMA helper when
    ``apply_g_ema`` is set, and returns the step at which the best FID
    was recorded.
    """
    Gen, Dis, Gen_ema = misc.peel_models(Gen, Dis, Gen_ema)
    # glob.escape protects against metacharacters ("[", "?", "*") in
    # ckpt_dir, consistent with load_StudioGAN_ckpts; an IndexError here
    # means no matching checkpoint file was found.
    Gen_ckpt_path = glob.glob(glob.escape(join(ckpt_dir, "model=G-best-weights-step")) + "*.pth")[0]
    Dis_ckpt_path = glob.glob(glob.escape(join(ckpt_dir, "model=D-best-weights-step")) + "*.pth")[0]

    load_ckpt(model=Gen,
              optimizer=None,
              ckpt_path=Gen_ckpt_path,
              load_model=True,
              load_opt=False,
              load_misc=False,
              is_freezeD=False)

    # Only best_step is needed from the discriminator's metadata.
    _, _, _, _, _, _, best_step, _, _, _ = load_ckpt(model=Dis,
                                                     optimizer=None,
                                                     ckpt_path=Dis_ckpt_path,
                                                     load_model=True,
                                                     load_opt=False,
                                                     load_misc=True,
                                                     is_freezeD=False)

    if apply_g_ema:
        Gen_ema_ckpt_path = glob.glob(glob.escape(join(ckpt_dir, "model=G_ema-best-weights-step")) + "*.pth")[0]
        load_ckpt(model=Gen_ema,
                  optimizer=None,
                  ckpt_path=Gen_ema_ckpt_path,
                  load_model=True,
                  load_opt=False,
                  load_misc=False,
                  is_freezeD=False)

        ema.source, ema.target = Gen, Gen_ema
    return best_step
def load_prev_dict(directory, file_name):
    """Load a pickled dictionary previously saved with ``np.save``.

    Args:
        directory: folder containing the saved ``.npy`` file.
        file_name: name of the file inside *directory*.

    Returns:
        The unpickled Python object stored in the file.
    """
    path = join(directory, file_name)
    return np.load(path, allow_pickle=True).item()
def check_is_pre_trained_model(ckpt_dir, GAN_train, GAN_test):
    """Check whether a pre-trained classifier checkpoint exists in *ckpt_dir*.

    Args:
        ckpt_dir: directory searched for "model=C-<mode>-best-weights.pth".
        GAN_train: truthy when evaluating in GAN_train mode (classifier
            trained on fake data).
        GAN_test: truthy when evaluating in GAN_test mode.

    Returns:
        ``(is_pre_train_model, mode)`` where *mode* is "fake_trained" for
        GAN_train and "real_trained" otherwise.

    Raises:
        AssertionError: if both GAN_train and GAN_test are enabled.
    """
    # Fixed typo in the message ("togather" -> "together").
    assert GAN_train * GAN_test == 0, "cannot conduct GAN_train and GAN_test together."
    mode = "fake_trained" if GAN_train else "real_trained"

    ckpt_list = glob.glob(join(ckpt_dir, "model=C-{mode}-best-weights.pth".format(mode=mode)))
    is_pre_train_model = len(ckpt_list) != 0
    return is_pre_train_model, mode
def load_GAN_train_test_model(model, mode, optimizer, RUN):
    """Restore the evaluation classifier and its optimizer from its best checkpoint.

    Args:
        model: classifier whose parameters are restored in place.
        mode: "fake_trained" or "real_trained"; selects the checkpoint file.
        optimizer: optimizer whose state is restored in place.
        RUN: run configuration providing ``ckpt_dir``.

    Returns:
        ``(epoch_trained, best_top1, best_top5, best_epoch)`` as stored in
        the checkpoint.
    """
    ckpt_path = join(RUN.ckpt_dir, "model=C-{mode}-best-weights.pth".format(mode=mode))
    # Deserialize onto CPU regardless of where the checkpoint was saved.
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return (checkpoint["epoch"],
            checkpoint["best_top1"],
            checkpoint["best_top5"],
            checkpoint["best_epoch"])