# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/utils/misc.py

from os.path import dirname, exists, join, isfile
from datetime import datetime
from collections import defaultdict
import random
import math
import os
import sys
import glob
import json
import warnings

from torch.nn import DataParallel
from torchvision.datasets import CIFAR10, CIFAR100
from torch.nn.parallel import DistributedDataParallel
from torchvision.utils import save_image
from itertools import chain
from tqdm import tqdm
from scipy import linalg
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.multiprocessing as mp
import torchvision.transforms as transforms
import shutil
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

import utils.sample as sample
import utils.losses as losses
import utils.ckpt as ckpt


class make_empty_object(object):
    pass


class dummy_context_mgr:
    # No-op context manager, used in place of torch.no_grad() when gradients
    # must be kept (see save_images_png below).
    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        return False


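# Illustrative sketch (hypothetical helper, not used by the pipeline): the
# conditional-context pattern that save_images_png uses below, where gradients
# are kept only when a loss flag requires them.
def _demo_dummy_context_mgr(apply_lo=False):
    with torch.no_grad() if not apply_lo else dummy_context_mgr():
        return torch.randn(2, 2).sum()

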
class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_input):
        self.outputs.append(module_input)

    def clear(self):
        self.outputs = []


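# Illustrative sketch (hypothetical helper): SaveOutput has the (module, input)
# signature of a torch forward pre-hook, so it can record the inputs that reach
# a given layer.
def _demo_save_output_hook():
    layer = torch.nn.Linear(4, 2)
    recorder = SaveOutput()
    handle = layer.register_forward_pre_hook(recorder)
    layer(torch.randn(3, 4))  # recorder.outputs now holds one (input,) tuple
    handle.remove()
    recorder.clear()

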
class GeneratorController(object):
    def __init__(self, generator, generator_mapping, generator_synthesis, batch_statistics, standing_statistics,
                 standing_max_batch, standing_step, cfgs, device, global_rank, logger, std_stat_counter):
        self.generator = generator
        self.generator_mapping = generator_mapping
        self.generator_synthesis = generator_synthesis
        self.batch_statistics = batch_statistics
        self.standing_statistics = standing_statistics
        self.standing_max_batch = standing_max_batch
        self.standing_step = standing_step
        self.cfgs = cfgs
        self.device = device
        self.global_rank = global_rank
        self.logger = logger
        self.std_stat_counter = std_stat_counter

    def prepare_generator(self):
        if self.standing_statistics:
            if self.std_stat_counter > 1:
                self.generator.eval()
                self.generator.apply(set_deterministic_op_trainable)
            else:
                self.generator.train()
                apply_standing_statistics(generator=self.generator,
                                          standing_max_batch=self.standing_max_batch,
                                          standing_step=self.standing_step,
                                          DATA=self.cfgs.DATA,
                                          MODEL=self.cfgs.MODEL,
                                          LOSS=self.cfgs.LOSS,
                                          OPTIMIZATION=self.cfgs.OPTIMIZATION,
                                          RUN=self.cfgs.RUN,
                                          STYLEGAN=self.cfgs.STYLEGAN,
                                          device=self.device,
                                          global_rank=self.global_rank,
                                          logger=self.logger)
                self.generator.eval()
                self.generator.apply(set_deterministic_op_trainable)
        else:
            self.generator.eval()
            if self.batch_statistics:
                self.generator.apply(set_bn_trainable)
                self.generator.apply(untrack_bn_statistics)
            self.generator.apply(set_deterministic_op_trainable)
        return self.generator, self.generator_mapping, self.generator_synthesis


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


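# Illustrative sketch (hypothetical helper): accumulate per-batch values and
# read back the running average.
def _demo_average_meter():
    meter = AverageMeter()
    meter.update(0.9, n=32)  # e.g. mean loss over a batch of 32
    meter.update(0.7, n=32)
    assert abs(meter.avg - 0.8) < 1e-8

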
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        wrong_k = batch_size - correct_k
        res.append(100 - wrong_k.mul_(100.0 / batch_size))
    return res


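# Illustrative sketch (hypothetical helper): top-1 and top-5 precision from a
# batch of random logits.
def _demo_accuracy():
    logits = torch.randn(8, 10)  # 8 samples, 10 classes
    labels = torch.randint(0, 10, (8,))
    top1, top5 = accuracy(logits, labels, topk=(1, 5))
    return top1.item(), top5.item()

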
def prepare_folder(names, save_dir):
    for name in names:
        folder_path = join(save_dir, name)
        if not exists(folder_path):
            os.makedirs(folder_path)


def download_data_if_possible(data_name, data_dir):
    if data_name == "CIFAR10":
        CIFAR10(root=data_dir, train=True, download=True)
    elif data_name == "CIFAR100":
        CIFAR100(root=data_dir, train=True, download=True)


def fix_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)


def setup(rank, world_size, backend="nccl"):
    if sys.platform == "win32":
        # Distributed package only covers collective communications with Gloo
        # backend and FileStore on Windows platform. Set init_method parameter
        # in init_process_group to a local file.
        # Example init_method="file:///f:/libtmp/some_file"
        init_method = "file:///{your local file path}"

        # initialize the process group
        dist.init_process_group(backend, init_method=init_method, rank=rank, world_size=world_size)
    else:
        # initialize the process group
        dist.init_process_group(backend,
                                init_method="env://",
                                rank=rank,
                                world_size=world_size)


def cleanup():
    dist.destroy_process_group()


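# Illustrative sketch (hypothetical worker): one process per rank, launched with
# torch.multiprocessing. Assumes MASTER_ADDR and MASTER_PORT are set in the
# environment for the "env://" init method.
def _demo_distributed_worker(rank, world_size):
    setup(rank, world_size, backend="nccl")
    # ... build models, wrap them in DistributedDataParallel, train ...
    cleanup()

# A launcher would call: mp.spawn(_demo_distributed_worker, args=(world_size,), nprocs=world_size)

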
def count_parameters(module):
    return "Number of parameters: {num}".format(num=sum([p.data.nelement() for p in module.parameters()]))


def toggle_grad(model, grad, num_freeze_layers=-1, is_stylegan=False):
    model = peel_model(model)
    if is_stylegan:
        for name, param in model.named_parameters():
            param.requires_grad = grad
    else:
        try:
            num_blocks = len(model.in_dims)
            assert num_freeze_layers < num_blocks, \
                "cannot freeze the {nfl}th block > total {nb} blocks.".format(nfl=num_freeze_layers,
                                                                              nb=num_blocks)
        except AttributeError:
            # the model does not expose in_dims, so the freeze-range check is skipped
            pass

        if num_freeze_layers == -1:
            for name, param in model.named_parameters():
                param.requires_grad = grad
        else:
            assert grad, "cannot freeze the model when grad is False"
            for name, param in model.named_parameters():
                param.requires_grad = True
                for layer in range(num_freeze_layers):
                    block_name = "blocks.{layer}".format(layer=layer)
                    if block_name in name:
                        param.requires_grad = False


def load_log_dicts(directory, file_name, ph):
    try:
        log_dict = ckpt.load_prev_dict(directory=directory, file_name=file_name)
    except Exception:
        # fall back to the placeholder when no previous log exists
        log_dict = ph
    return log_dict


def make_model_require_grad(model):
    if isinstance(model, DataParallel) or isinstance(model, DistributedDataParallel):
        model = model.module

    for name, param in model.named_parameters():
        param.requires_grad = True


def identity(x):
    return x


def set_bn_trainable(m):
    if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
        m.train()


def untrack_bn_statistics(m):
    if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
        m.track_running_stats = False


def track_bn_statistics(m):
    if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
        m.track_running_stats = True


def set_deterministic_op_trainable(m):
    if isinstance(m, torch.nn.modules.conv.Conv2d):
        m.train()
    if isinstance(m, torch.nn.modules.conv.ConvTranspose2d):
        m.train()
    if isinstance(m, torch.nn.modules.linear.Linear):
        m.train()
    if isinstance(m, torch.nn.modules.Embedding):
        m.train()


def reset_bn_statistics(m):
    if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
        m.reset_running_stats()


def elapsed_time(start_time):
    now = datetime.now()
    elapsed = now - start_time
    return str(elapsed).split(".")[0]  # remove milliseconds


def reshape_weight_to_matrix(weight):
    # flatten a weight tensor to a 2D matrix; mirrors the helper inside
    # torch.nn.utils.spectral_norm (dim is always 0 for the weights used here)
    weight_mat = weight
    dim = 0
    if dim != 0:
        weight_mat = weight_mat.permute(dim, *[d for d in range(weight_mat.dim()) if d != dim])
    height = weight_mat.size(0)
    return weight_mat.reshape(height, -1)


def calculate_all_sn(model, prefix):
    sigmas = {}
    with torch.no_grad():
        for name, param in model.named_parameters():
            operations = model
            if "weight_orig" in name:
                split_name = name.split(".")
                for name_element in split_name[:-1]:
                    operations = getattr(operations, name_element)
                weight_orig = reshape_weight_to_matrix(operations.weight_orig)
                weight_u = operations.weight_u
                weight_v = operations.weight_v
                # power-iteration estimate of the largest singular value: sigma = u^T W v
                sigmas[prefix + "_" + name] = torch.dot(weight_u, torch.mv(weight_orig, weight_v)).item()
    return sigmas


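# Illustrative sketch (hypothetical helper): torch.nn.utils.spectral_norm
# registers weight_orig, weight_u, and weight_v, which calculate_all_sn reads
# back to estimate each layer's spectral norm.
def _demo_calculate_all_sn():
    layer = torch.nn.utils.spectral_norm(torch.nn.Linear(8, 4))
    layer(torch.randn(2, 8))                    # a forward pass refines u and v
    return calculate_all_sn(layer, prefix="D")  # {"D_weight_orig": ...}

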
def apply_standing_statistics(generator, standing_max_batch, standing_step, DATA, MODEL, LOSS, OPTIMIZATION, RUN, STYLEGAN,
                              device, global_rank, logger):
    generator.train()
    generator.apply(reset_bn_statistics)
    if global_rank == 0:
        logger.info("Accumulate statistics of batchnorm layers to improve generation performance.")
    for i in tqdm(range(standing_step)):
        batch_size_per_gpu = standing_max_batch // OPTIMIZATION.world_size
        if RUN.distributed_data_parallel:
            rand_batch_size = random.randint(1, batch_size_per_gpu)
        else:
            rand_batch_size = random.randint(1, batch_size_per_gpu) * OPTIMIZATION.world_size
        fake_images, fake_labels, _, _, _, _, _ = sample.generate_images(z_prior=MODEL.z_prior,
                                                                         truncation_factor=-1,
                                                                         batch_size=rand_batch_size,
                                                                         z_dim=MODEL.z_dim,
                                                                         num_classes=DATA.num_classes,
                                                                         y_sampler="totally_random",
                                                                         radius="N/A",
                                                                         generator=generator,
                                                                         discriminator=None,
                                                                         is_train=True,
                                                                         LOSS=LOSS,
                                                                         RUN=RUN,
                                                                         MODEL=MODEL,
                                                                         is_stylegan=MODEL.backbone in ["stylegan2", "stylegan3"],
                                                                         generator_mapping=None,
                                                                         generator_synthesis=None,
                                                                         style_mixing_p=0.0,
                                                                         stylegan_update_emas=False,
                                                                         device=device,
                                                                         cal_trsp_cost=False)
    generator.eval()


def define_sampler(dataset_name, dis_cond_mtd, batch_size, num_classes):
    # NOTE: the "acending_*" spellings are kept verbatim; they are sampler
    # identifiers matched elsewhere in the codebase.
    if dis_cond_mtd != "W/O":
        if dataset_name == "CIFAR10" or batch_size >= num_classes * 8:
            sampler = "acending_all"
        else:
            sampler = "acending_some"
    else:
        sampler = "totally_random"
    return sampler


def make_GAN_trainable(Gen, Gen_ema, Dis):
    Gen.train()
    Gen.apply(track_bn_statistics)
    if Gen_ema is not None:
        Gen_ema.train()
        Gen_ema.apply(track_bn_statistics)

    Dis.train()
    Dis.apply(track_bn_statistics)


def make_GAN_untrainable(Gen, Gen_ema, Dis):
    Gen.eval()
    Gen.apply(set_deterministic_op_trainable)
    if Gen_ema is not None:
        Gen_ema.eval()
        Gen_ema.apply(set_deterministic_op_trainable)

    Dis.eval()
    Dis.apply(set_deterministic_op_trainable)


def peel_models(Gen, Gen_ema, Dis):
    if isinstance(Dis, DataParallel) or isinstance(Dis, DistributedDataParallel):
        dis = Dis.module
    else:
        dis = Dis

    if isinstance(Gen, DataParallel) or isinstance(Gen, DistributedDataParallel):
        gen = Gen.module
    else:
        gen = Gen

    if Gen_ema is not None:
        if isinstance(Gen_ema, DataParallel) or isinstance(Gen_ema, DistributedDataParallel):
            gen_ema = Gen_ema.module
        else:
            gen_ema = Gen_ema
    else:
        gen_ema = None
    return gen, gen_ema, dis


def peel_model(model):
    if isinstance(model, DataParallel) or isinstance(model, DistributedDataParallel):
        model = model.module
    return model


def save_model(model, when, step, ckpt_dir, states):
    model_tpl = "model={model}-{when}-weights-step={step}.pth"
    model_ckpt_list = glob.glob(join(ckpt_dir, model_tpl.format(model=model, when=when, step="*")))
    if len(model_ckpt_list) > 0:
        find_and_remove(model_ckpt_list[0])

    torch.save(states, join(ckpt_dir, model_tpl.format(model=model, when=when, step=step)))


def save_model_c(states, mode, RUN):
    ckpt_path = join(RUN.ckpt_dir, "model=C-{mode}-best-weights.pth".format(mode=mode))
    torch.save(states, ckpt_path)


def find_string(list_, string):
    for i, s in enumerate(list_):
        if string == s:
            return i


def find_and_remove(path):
    if isfile(path):
        os.remove(path)


def plot_img_canvas(images, save_path, num_cols, logger, logging=True):
    if logger is None:
        logging = False
    directory = dirname(save_path)

    if not exists(directory):
        os.makedirs(directory)

    save_image(((images + 1) / 2).clamp(0.0, 1.0), save_path, padding=0, nrow=num_cols)
    if logging:
        logger.info("Saved image canvas to {}".format(save_path))


def plot_spectrum_image(real_spectrum, fake_spectrum, directory, logger, logging=True):
    if logger is None:
        logging = False

    if not exists(directory):
        os.makedirs(directory)

    save_path = join(directory, "dfft_spectrum.png")

    fig = plt.figure()
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)

    ax1.imshow(real_spectrum, cmap="viridis")
    ax1.set_title("Spectrum of real images")

    ax2.imshow(fake_spectrum, cmap="viridis")
    ax2.set_title("Spectrum of fake images")
    fig.savefig(save_path)
    if logging:
        logger.info("Saved image to {}".format(save_path))


def plot_tsne_scatter_plot(df, tsne_results, flag, directory, logger, logging=True):
    if logger is None:
        logging = False

    if not exists(directory):
        os.makedirs(directory)

    save_path = join(directory, "tsne_scatter_{flag}.png".format(flag=flag))

    df["tsne-2d-one"] = tsne_results[:, 0]
    df["tsne-2d-two"] = tsne_results[:, 1]
    plt.figure(figsize=(16, 10))
    sns.scatterplot(x="tsne-2d-one",
                    y="tsne-2d-two",
                    hue="labels",
                    palette=sns.color_palette("hls", 10),
                    data=df,
                    legend="full",
                    alpha=0.5).legend(fontsize=15, loc="upper right")
    plt.title("t-SNE result of {flag} images".format(flag=flag), fontsize=25)
    plt.xlabel("", fontsize=7)
    plt.ylabel("", fontsize=7)
    plt.savefig(save_path)
    if logging:
        logger.info("Saved image to {path}".format(path=save_path))


def save_images_png(data_loader, generator, discriminator, is_generate, num_images, y_sampler, batch_size, z_prior,
                    truncation_factor, z_dim, num_classes, LOSS, OPTIMIZATION, RUN, MODEL, is_stylegan, generator_mapping,
                    generator_synthesis, directory, device):
    num_batches = math.ceil(float(num_images) / float(batch_size))
    if RUN.distributed_data_parallel:
        num_batches = num_batches // OPTIMIZATION.world_size + 1
    if is_generate:
        image_type = "fake"
    else:
        image_type = "real"
        data_iter = iter(data_loader)

    print("Saving {num_images} {image_type} images in png format.".format(num_images=num_images, image_type=image_type))

    directory = join(directory, image_type)
    if exists(directory):
        shutil.rmtree(directory)
    os.makedirs(directory)
    for f in range(num_classes):
        os.makedirs(join(directory, str(f)))

    with torch.no_grad() if not LOSS.apply_lo else dummy_context_mgr() as mpc:
        for i in tqdm(range(0, num_batches), disable=False):
            start = i * batch_size
            end = start + batch_size
            if is_generate:
                images, labels, _, _, _, _, _ = sample.generate_images(z_prior=z_prior,
                                                                       truncation_factor=truncation_factor,
                                                                       batch_size=batch_size,
                                                                       z_dim=z_dim,
                                                                       num_classes=num_classes,
                                                                       y_sampler=y_sampler,
                                                                       radius="N/A",
                                                                       generator=generator,
                                                                       discriminator=discriminator,
                                                                       is_train=False,
                                                                       LOSS=LOSS,
                                                                       RUN=RUN,
                                                                       MODEL=MODEL,
                                                                       is_stylegan=is_stylegan,
                                                                       generator_mapping=generator_mapping,
                                                                       generator_synthesis=generator_synthesis,
                                                                       style_mixing_p=0.0,
                                                                       stylegan_update_emas=False,
                                                                       device=device,
                                                                       cal_trsp_cost=False)
            else:
                try:
                    images, labels = next(data_iter)
                except StopIteration:
                    break

            for idx, img in enumerate(images.detach()):
                if batch_size * i + idx < num_images:
                    save_image(((img + 1) / 2).clamp(0.0, 1.0),
                               join(directory, str(labels[idx].item()), "{idx}.png".format(idx=batch_size * i + idx)))

    print("Finished saving png images to {directory}/*/*.png".format(directory=directory))


def orthogonalize_model(model, strength=1e-4, blacklist=[]):
    # BigGAN-style orthogonal regularization: nudge the gradient so that the
    # off-diagonal entries of W W^T are pushed toward zero
    with torch.no_grad():
        for param in model.parameters():
            if len(param.shape) < 2 or any([param is item for item in blacklist]):
                continue
            w = param.view(param.shape[0], -1)
            grad = (2 * torch.mm(torch.mm(w, w.t()) * (1. - torch.eye(w.shape[0], device=w.device)), w))
            param.grad.data += strength * grad.view(param.shape)


def interpolate(x0, x1, num_midpoints):
    # linear interpolation from x0 to x1 with num_midpoints points in between
    lerp = torch.linspace(0, 1.0, num_midpoints + 2, device="cuda").to(x0.dtype)
    return ((x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1)))


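# Illustrative sketch (hypothetical helper): lerp.view(1, -1, 1) broadcasts
# against inputs of shape (batch, 1, dim), yielding (batch, num_midpoints + 2, dim)
# latents that walk from x0 to x1. Requires a CUDA device, as interpolate does.
def _demo_interpolate():
    x0 = torch.randn(4, 1, 128, device="cuda")
    x1 = torch.randn(4, 1, 128, device="cuda")
    return interpolate(x0, x1, num_midpoints=6)  # shape: (4, 8, 128)

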
def accm_values_convert_dict(list_dict, value_dict, step, interval):
    for name, value_list in list_dict.items():
        if step is None:
            value_list += [value_dict[name]]
        else:
            try:
                # overwrite the slot that corresponds to `step` when it already exists
                value_list[step // interval - 1] = value_dict[name]
            except IndexError:
                value_list += [value_dict[name]]
        list_dict[name] = value_list
    return list_dict


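# Illustrative sketch (hypothetical helper): append a new metric reading, or
# overwrite the slot for `step` when resuming (interval = evaluation frequency).
def _demo_accm_values_convert_dict():
    history = {"FID": [10.0, 8.0]}
    history = accm_values_convert_dict(history, {"FID": 7.5}, step=3000, interval=1000)
    return history  # {"FID": [10.0, 8.0, 7.5]}

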
def save_dict_npy(directory, name, dictionary):
    if not exists(directory):
        os.makedirs(directory)

    save_path = join(directory, name + ".npy")
    np.save(save_path, dictionary)


def load_ImageNet_label_dict(data_name, is_torch_backbone):
    if data_name in ["Baby_ImageNet", "Papa_ImageNet", "Grandpa_ImageNet"] and is_torch_backbone:
        with open("./src/utils/pytorch_imagenet_folder_label_pairs.json", "r") as f:
            ImageNet_folder_label_dict = json.load(f)
    else:
        ImageNet_folder_label_dict, label = {}, 0
        with open("./src/utils/tf_imagenet_folder_label_pairs.txt", "r") as label_table:
            for line in label_table:
                folder = line.split(" ")[0]
                ImageNet_folder_label_dict[folder] = label
                label += 1
    return ImageNet_folder_label_dict


def compute_gradient(fx, logits, label, num_classes):
    probs = torch.nn.Softmax(dim=1)(logits.detach().cpu())
    gt_prob = F.one_hot(label, num_classes)
    oneMp = gt_prob - probs
    preds = (probs * gt_prob).sum(-1)
    grad = torch.mean(fx.unsqueeze(1) * oneMp.unsqueeze(2), dim=0)
    return fx.norm(dim=1), preds, torch.norm(grad, dim=1)


def load_parameters(src, dst, strict=True):
    mismatch_names = []
    for dst_key, dst_value in dst.items():
        if dst_key in src:
            if dst_value.shape == src[dst_key].shape:
                dst[dst_key].copy_(src[dst_key])
            else:
                mismatch_names.append(dst_key)
                err = "source tensor {key}({src}) does not match with destination tensor {key}({dst}).".\
                    format(key=dst_key, src=src[dst_key].shape, dst=dst_value.shape)
                assert not strict, err
        else:
            mismatch_names.append(dst_key)
            assert not strict, "dst_key is not in src_dict."
    return mismatch_names


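# Illustrative sketch (hypothetical helper): copy matching tensors from one
# state_dict into another, collecting the names that do not line up.
def _demo_load_parameters():
    src = torch.nn.Linear(4, 4).state_dict()
    dst = torch.nn.Linear(4, 4).state_dict()
    return load_parameters(src, dst, strict=False)  # [] when everything matches

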
def enable_allreduce(dict_):
    # touch every non-label output with a zero-weighted mean so that DDP's
    # allreduce covers all parameters without changing the loss value
    loss = 0
    for key, value in dict_.items():
        if value is not None and key != "label":
            loss += value.mean() * 0
    return loss


def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
    if os.path.isfile(pretrained_weights):
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        if checkpoint_key is not None and checkpoint_key in state_dict:
            print(f"Take key {checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
    else:
        print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
        url = None
        if model_name == "vit_small" and patch_size == 16:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        elif model_name == "vit_small" and patch_size == 8:
            url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
        elif model_name == "vit_base" and patch_size == 16:
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        elif model_name == "vit_base" and patch_size == 8:
            url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
        elif model_name == "xcit_small_12_p16":
            url = "dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth"
        elif model_name == "xcit_small_12_p8":
            url = "dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth"
        elif model_name == "xcit_medium_24_p16":
            url = "dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth"
        elif model_name == "xcit_medium_24_p8":
            url = "dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth"
        elif model_name == "resnet50":
            url = "dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
        if url is not None:
            print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
            state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
            model.load_state_dict(state_dict, strict=False)
        else:
            print("There are no reference weights available for this model => We use random weights.")


def load_pretrained_linear_weights(linear_classifier, model_name, patch_size):
    url = None
    if model_name == "vit_small" and patch_size == 16:
        url = "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth"
    elif model_name == "vit_small" and patch_size == 8:
        url = "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth"
    elif model_name == "vit_base" and patch_size == 16:
        url = "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth"
    elif model_name == "vit_base" and patch_size == 8:
        url = "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth"
    elif model_name == "resnet50":
        url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth"
    if url is not None:
        print("We load the reference pretrained linear weights.")
        state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)["state_dict"]
        state_dict = {k.replace("module.linear.", ""): v for k, v in state_dict.items()}
        linear_classifier.load_state_dict(state_dict, strict=True)
    else:
        print("We use random linear weights.")