Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
TencentARC
GitHub Repository: TencentARC/GFPGAN
Path: blob/master/cog_predict.py
884 views
1
# flake8: noqa
2
# This file is used for deploying replicate models
3
# running: cog predict -i img=@inputs/whole_imgs/10045.png -i version='v1.4' -i scale=2
4
# push: cog push r8.im/tencentarc/gfpgan
5
# push (backup): cog push r8.im/xinntao/gfpgan
6
7
import os


def _run_install(cmd):
    """Run an install command through the shell; report (but do not abort on) a non-zero exit status."""
    status = os.system(cmd)
    if status != 0:
        print(f'command failed (exit status {status}): {cmd}')
    return status


# These installs must run BEFORE the imports below: `gfpgan` is this repository
# (installed in development mode by setup.py) and `realesrgan` is pip-installed here.
# Keep the statement order as-is.
_run_install('python setup.py develop')
_run_install('pip install realesrgan')

import cv2
import shutil
import tempfile
import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact

from gfpgan import GFPGANer

try:
    from cog import BasePredictor, Input, Path
    from realesrgan.utils import RealESRGANer
except Exception:
    print('please install cog and realesrgan package')
    # Re-raise: `Predictor` subclasses BasePredictor and `setup` calls RealESRGANer,
    # so continuing would only fail later with a less informative NameError.
    raise
25
26
27
class Predictor(BasePredictor):
    """Cog predictor: GFPGAN-family face restoration with a Real-ESRGAN background upsampler."""

    # Face-restoration checkpoints selectable through ``predict(version=...)``:
    # version name -> (weights path, ``arch`` argument for GFPGANer).
    _FACE_MODELS = {
        'v1.2': ('gfpgan/weights/GFPGANv1.2.pth', 'clean'),
        'v1.3': ('gfpgan/weights/GFPGANv1.3.pth', 'clean'),
        'v1.4': ('gfpgan/weights/GFPGANv1.4.pth', 'clean'),
        'RestoreFormer': ('gfpgan/weights/RestoreFormer.pth', 'RestoreFormer'),
    }

    # Weights fetched by ``setup`` when missing (in this order): local path -> release URL.
    _WEIGHT_URLS = {
        'gfpgan/weights/realesr-general-x4v3.pth':
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
        'gfpgan/weights/GFPGANv1.2.pth':
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
        'gfpgan/weights/GFPGANv1.3.pth':
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
        'gfpgan/weights/GFPGANv1.4.pth':
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
        'gfpgan/weights/RestoreFormer.pth':
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
    }

    def setup(self):
        """Download any missing weights, then load the Real-ESRGAN upsampler and the default v1.4 face model."""
        os.makedirs('output', exist_ok=True)
        # download weights
        for weight_path, url in self._WEIGHT_URLS.items():
            if not os.path.exists(weight_path):
                status = os.system(f'wget {url} -P ./gfpgan/weights')
                if status != 0:
                    # Not fatal here; a missing file is expected to fail when the model that needs it is loaded.
                    print(f'failed to download {url} (exit status {status})')

        # background enhancer with RealESRGAN
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        # Half precision is requested only when CUDA is available.
        half = torch.cuda.is_available()
        self.upsampler = RealESRGANer(
            scale=4,
            model_path='gfpgan/weights/realesr-general-x4v3.pth',
            model=model,
            tile=0,
            tile_pad=10,
            pre_pad=0,
            half=half)

        # Use GFPGAN for face enhancement
        self.current_version = None
        self._load_face_enhancer('v1.4')

    def _load_face_enhancer(self, version):
        """Build the GFPGANer for ``version`` and record it as the currently loaded model.

        Raises:
            ValueError: if ``version`` is not one of ``_FACE_MODELS``.
        """
        try:
            model_path, arch = self._FACE_MODELS[version]
        except KeyError:
            raise ValueError(f'unsupported version: {version!r}') from None
        self.face_enhancer = GFPGANer(
            model_path=model_path,
            upscale=2,
            arch=arch,
            channel_multiplier=2,
            bg_upsampler=self.upsampler)
        # Recorded for every version (RestoreFormer included) so repeated calls reuse the loaded model.
        self.current_version = version

    def predict(
            self,
            img: Path = Input(description='Input'),
            version: str = Input(
                description='GFPGAN version. v1.3: better quality. v1.4: more details and better identity.',
                choices=['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'],
                default='v1.4'),
            scale: float = Input(description='Rescaling factor', default=2),
    ) -> Path:
        """Restore faces in ``img`` and return the path of the saved result.

        Args:
            img: path of the input image, read with ``cv2.IMREAD_UNCHANGED``.
            version: face-restoration checkpoint. The model is rebuilt only when it
                differs from the one currently loaded.
            scale: output size relative to the (possibly pre-enlarged) input image.
                GFPGANer is built with ``upscale=2``, so any other value triggers an
                extra resize of the enhanced image by ``scale / 2``.

        Returns:
            Path of ``out.<ext>`` in a fresh temporary directory. ``<ext>`` is the input's
            extension; RGBA inputs and inputs without an extension are written as PNG.

        Raises:
            Any error while reading, enhancing or writing the image is printed and re-raised
            (a bad ``scale`` is only reported and the unscaled result is kept).
        """
        # Passed as ``weight`` to GFPGANer.enhance (presumably a restored/original blend factor — confirm in GFPGANer).
        weight = 0.5
        print(img, version, scale, weight)
        try:
            input_path = str(img)
            # splitext keeps the leading dot ('.png'); strip it so the file is 'out.png', not 'out..png'.
            extension = os.path.splitext(os.path.basename(input_path))[1].lstrip('.')
            img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
            if img is None:
                raise ValueError(f'could not read input image: {input_path}')

            img_mode = 'RGBA' if img.ndim == 3 and img.shape[2] == 4 else None
            if img.ndim == 2:
                # Grayscale input: expand to 3 channels.
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            h, w = img.shape[0:2]
            if h < 300:
                # Small inputs are pre-enlarged 2x before enhancement (presumably to help face detection — confirm).
                img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

            if self.current_version != version:
                self._load_face_enhancer(version)

            try:
                _, _, output = self.face_enhancer.enhance(
                    img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
            except RuntimeError as error:
                print('Error', error)
                raise

            try:
                if scale != 2:
                    interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
                    # Resize relative to the enhanced (already 2x) image so the result is `scale` x `img`.
                    out_h, out_w = output.shape[0:2]
                    output = cv2.resize(
                        output, (int(out_w * scale / 2), int(out_h * scale / 2)), interpolation=interpolation)
            except Exception as error:
                # Deliberately best-effort: keep the unscaled result.
                print('wrong scale input.', error)

            if img_mode == 'RGBA' or not extension:  # RGBA images should be saved in png format
                extension = 'png'
            out_path = Path(tempfile.mkdtemp()) / f'out.{extension}'
            if not cv2.imwrite(str(out_path), output):
                raise RuntimeError(f'failed to write output image: {out_path}')
        except Exception as error:
            print('global exception: ', error)
            raise
        finally:
            clean_folder('output')
        return out_path
150
151
152
def clean_folder(folder):
    """Delete every file, symlink and sub-directory inside ``folder``, keeping ``folder`` itself.

    Deletion is best-effort: an entry that cannot be removed is reported with ``print``
    and skipped. A missing ``folder`` is treated as already clean. This function is
    called from ``Predictor.predict``'s ``finally`` clause, where raising would replace
    the prediction's result (or its original error).

    Args:
        folder: path of the directory whose contents should be removed.
    """
    if not os.path.isdir(folder):
        return
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            # Links are unlinked before the isdir check, so a symlinked directory is
            # removed as a link rather than recursed into.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f'Failed to delete {file_path}. Reason: {e}')
162
163