Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
TencentARC
GitHub Repository: TencentARC/GFPGAN
Path: blob/master/gfpgan/utils.py
884 views
1
import cv2
2
import os
3
import torch
4
from basicsr.utils import img2tensor, tensor2img
5
from basicsr.utils.download_util import load_file_from_url
6
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
7
from torchvision.transforms.functional import normalize
8
9
from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
10
from gfpgan.archs.gfpganv1_arch import GFPGANv1
11
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
12
13
# Directory two levels above this module (os.path.split(p)[0] is os.path.dirname(p)).
ROOT_DIR = os.path.split(os.path.split(os.path.abspath(__file__))[0])[0]
14
15
16
class GFPGANer():
    """Helper for restoration with GFPGAN.

    It will detect and crop faces, and then resize the faces to 512x512.
    GFPGAN is used to restore the resized faces.
    The background is upsampled with the bg_upsampler.
    Finally, the faces will be pasted back to the upsampled background image.

    Args:
        model_path (str): The path to the GFPGAN model. It can also be an http:// or https:// URL,
            in which case the file is downloaded automatically first.
        upscale (float): The upscale of the final output. Default: 2.
        arch (str): The GFPGAN architecture. Option: clean | bilinear | original | RestoreFormer.
            Default: clean.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Not used by the
            RestoreFormer arch. Default: 2.
        bg_upsampler (nn.Module): The upsampler for the background. Default: None.
        device: Device for the network and the face helper. Default: None, which selects 'cuda' if
            ``torch.cuda.is_available()`` and 'cpu' otherwise.

    Raises:
        ValueError: If ``arch`` is not one of the supported options.
    """

    def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        # initialize the GFP-GAN (fails fast with ValueError on an unknown arch)
        self.gfpgan = self._build_network(arch, channel_multiplier)
        # initialize face helper
        # NOTE(review): 'gfpgan/weights' is a relative path (presumably resolved against the current working
        # directory), whereas GFPGAN checkpoints below are downloaded under ROOT_DIR -- confirm this is intended.
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            use_parse=True,
            device=self.device,
            model_rootpath='gfpgan/weights')

        if model_path.startswith(('https://', 'http://')):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
        # Load onto the CPU first so checkpoints saved from a GPU also work on CPU-only machines; the network
        # is moved to self.device below. NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        loadnet = torch.load(model_path, map_location='cpu')
        # Prefer 'params_ema' (presumably EMA weights) when the checkpoint provides it.
        keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

    @staticmethod
    def _build_network(arch, channel_multiplier):
        """Instantiate the (not yet loaded) restoration network for ``arch``.

        Raises:
            ValueError: If ``arch`` is not clean | bilinear | original | RestoreFormer.
        """
        if arch == 'RestoreFormer':
            # Imported only when requested, presumably so the other archs do not need this module.
            from gfpgan.archs.restoreformer_arch import RestoreFormer
            return RestoreFormer()

        if arch == 'clean':
            net_cls, fix_decoder = GFPGANv1Clean, False
        elif arch == 'bilinear':
            net_cls, fix_decoder = GFPGANBilinear, False
        elif arch == 'original':
            # Only the original arch is built with fix_decoder=True.
            net_cls, fix_decoder = GFPGANv1, True
        else:
            raise ValueError(f"Unsupported arch '{arch}'. Option: clean | bilinear | original | RestoreFormer.")

        # All three GFPGAN variants share these hyper-parameters.
        return net_cls(
            out_size=512,
            num_style_feat=512,
            channel_multiplier=channel_multiplier,
            decoder_load_path=None,
            fix_decoder=fix_decoder,
            num_mlp=8,
            input_is_latent=True,
            different_w=True,
            narrow=1,
            sft_half=True)

    @torch.no_grad()
    def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
        """Detect, restore and optionally paste back the faces in an image.

        Runs under ``torch.no_grad()``.

        Args:
            img (ndarray): Input image, assumed to be BGR with values in [0, 255] (e.g. from cv2.imread).
            has_aligned (bool): Whether ``img`` is already an aligned face crop. If True, it is resized to
                512x512 and restored directly, without detection or paste-back. Default: False.
            only_center_face (bool): Forwarded to the face helper's landmark detection. Default: False.
            paste_back (bool): Whether to paste the restored faces back into the (optionally
                background-upsampled) input image. Ignored when ``has_aligned`` is True. Default: True.
            weight (float): Forwarded as ``weight`` to the network's forward call. Default: 0.5.

        Returns:
            tuple: ``(cropped_faces, restored_faces, restored_img)``. These are the face helper's lists of
            cropped and restored faces, plus the pasted-back image. The image is None unless
            ``has_aligned`` is False and ``paste_back`` is True.
        """
        # drop any state the face helper holds from a previous call
        self.face_helper.clean_all()

        if has_aligned:  # the inputs are already aligned
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data: scale to [0, 1], BGR -> RGB tensor, then normalize to [-1, 1]
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
                # convert to image
                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
            except RuntimeError as error:
                print(f'\tFailed inference for GFPGAN: {error}.')
                # best effort: keep the unrestored crop so there is still one restored face per crop
                restored_face = cropped_face

            restored_face = restored_face.astype('uint8')
            self.face_helper.add_restored_face(restored_face)

        if not has_aligned and paste_back:
            # upsample the background
            if self.bg_upsampler is not None:
                # Now only support RealESRGAN for upsampling background
                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
            else:
                bg_img = None

            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
        else:
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
149
150