Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
TencentARC
GitHub Repository: TencentARC/GFPGAN
Path: blob/master/tests/test_gfpgan_model.py
884 views
1
import tempfile
2
import torch
3
import yaml
4
from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator
5
from basicsr.data.paired_image_dataset import PairedImageDataset
6
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
7
8
from gfpgan.archs.arcface_arch import ResNetArcFace
9
from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1
10
from gfpgan.models.gfpgan_model import GFPGANModel
11
12
13
def test_gfpgan_model():
    """End-to-end smoke test for ``GFPGANModel``.

    Builds the model from ``tests/data/test_gfpgan_model.yml`` and checks, in order:
    network / loss / optimizer construction, ``feed_data``, ``optimize_parameters``
    (at iteration 1 and at an iteration past the pyramid-loss removal point),
    ``save``, ``test`` (with and without ``net_g_ema``), and ``nondist_validation``
    with ``is_train`` left as-is and with ``is_train = False``.

    All fixture paths are relative, so the test must be run from the repository root.
    """
    with open('tests/data/test_gfpgan_model.yml', mode='r', encoding='utf-8') as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)

    img_shape = (1, 3, 512, 512)

    # build model
    model = GFPGANModel(opt)
    # test attributes
    assert type(model).__name__ == 'GFPGANModel'
    assert isinstance(model.net_g, GFPGANv1)  # generator
    assert isinstance(model.net_d, StyleGAN2Discriminator)  # discriminator
    # facial component discriminators
    assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator)
    assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator)
    assert isinstance(model.net_d_mouth, FacialComponentDiscriminator)
    # identity network
    assert isinstance(model.network_identity, ResNetArcFace)
    # losses
    assert isinstance(model.cri_pix, L1Loss)
    assert isinstance(model.cri_perceptual, PerceptualLoss)
    assert isinstance(model.cri_gan, GANLoss)
    assert isinstance(model.cri_l1, L1Loss)
    # optimizers
    assert isinstance(model.optimizers[0], torch.optim.Adam)
    assert isinstance(model.optimizers[1], torch.optim.Adam)

    # prepare data (random tensors; the loc_* boxes are 4 values per sample)
    gt = torch.rand(img_shape, dtype=torch.float32)
    lq = torch.rand(img_shape, dtype=torch.float32)
    loc_left_eye = torch.rand((1, 4), dtype=torch.float32)
    loc_right_eye = torch.rand((1, 4), dtype=torch.float32)
    loc_mouth = torch.rand((1, 4), dtype=torch.float32)
    data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth)
    model.feed_data(data)
    # check data shape (note: feed_data stores the boxes under plural attribute names)
    assert model.lq.shape == img_shape
    assert model.gt.shape == img_shape
    assert model.loc_left_eyes.shape == (1, 4)
    assert model.loc_right_eyes.shape == (1, 4)
    assert model.loc_mouths.shape == (1, 4)

    # Log keys that every optimize_parameters() call must report; the same set is
    # required both before and after the pyramid loss is removed.
    expected_keys = {
        'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
        'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
        'l_d_right_eye', 'l_d_mouth'
    }

    def check_optimize_step(current_iter):
        # Run one training step at `current_iter` and validate output shape and logged losses.
        model.feed_data(data)
        model.optimize_parameters(current_iter)
        assert model.output.shape == img_shape
        assert isinstance(model.log_dict, dict)
        missing = expected_keys.difference(model.log_dict)
        assert not missing, f'missing log keys at iter {current_iter}: {sorted(missing)}'

    def check_psnr():
        # Validation must have produced a float PSNR metric.
        assert 'psnr' in model.metric_results
        assert isinstance(model.metric_results['psnr'], float)

    # ----------------- test optimize_parameters -------------------- #
    check_optimize_step(1)

    # ----------------- remove pyramid_loss_weight -------------------- #
    # 100000 is larger than remove_pyramid_loss (50000), so this step runs the
    # code path where the pyramid loss has been removed.
    check_optimize_step(100000)

    # ----------------- test save -------------------- #
    with tempfile.TemporaryDirectory() as tmpdir:
        model.opt['path']['models'] = tmpdir
        model.opt['path']['training_states'] = tmpdir
        model.save(0, 1)

    # ----------------- test the test function -------------------- #
    model.test()
    assert model.output.shape == img_shape
    # exercise test() again with net_g_ema absent
    del model.net_g_ema
    model.test()
    assert model.output.shape == img_shape
    assert model.net_g.training is True  # should be back in training mode after testing

    # ----------------- test nondist_validation -------------------- #
    # construct dataloader
    # NOTE(review): lq is read from the gt folder as well; presumably intentional
    # for this fixture — confirm.
    dataset_opt = dict(
        name='Demo',
        dataroot_gt='tests/data/gt',
        dataroot_lq='tests/data/gt',
        io_backend=dict(type='disk'),
        scale=4,
        phase='val')
    dataset = PairedImageDataset(dataset_opt)
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    assert model.is_train is True
    with tempfile.TemporaryDirectory() as tmpdir:
        model.opt['path']['visualization'] = tmpdir
        model.nondist_validation(dataloader, 1, None, save_img=True)
        assert model.is_train is True
        check_psnr()

    # validation with is_train disabled; every image-saving call stays inside the
    # context so nothing is written to a directory that has already been removed.
    with tempfile.TemporaryDirectory() as tmpdir:
        model.opt['is_train'] = False
        model.opt['val']['suffix'] = 'test'
        model.opt['path']['visualization'] = tmpdir
        model.opt['val']['pbar'] = True
        model.nondist_validation(dataloader, 1, None, save_img=True)
        check_psnr()

        # if opt['val']['suffix'] is None
        model.opt['val']['suffix'] = None
        model.opt['name'] = 'demo'
        model.opt['path']['visualization'] = tmpdir
        model.nondist_validation(dataloader, 1, None, save_img=True)
        check_psnr()
133
134