Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
TencentARC
GitHub Repository: TencentARC/GFPGAN
Path: blob/master/tests/test_arcface_arch.py
884 views
1
import torch
2
3
from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace
4
5
6
def test_resnetarcface():
    """Test arch: ResNetArcFace.

    Builds the network with and without SE blocks and checks the forward
    output shape. Runs on CUDA when available and falls back to CPU otherwise,
    so the test always exercises the architecture instead of doing nothing on
    hosts without a GPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # model init and forward, with SE blocks
    net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True).to(device).eval()
    img = torch.rand((1, 1, 128, 128), dtype=torch.float32, device=device)
    output = net(img)
    assert output.shape == (1, 512)

    # -------------------- without SE block ----------------------- #
    # Reuses the same input tensor; only the SE option differs.
    net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).to(device).eval()
    output = net(img)
    assert output.shape == (1, 512)
20
21
22
def test_basicblock():
    """Test the BasicBlock in arcface_arch.

    Checks the output shape with no downsample module (stride 1) and with a
    0.5x nearest-neighbour resize as the downsample module (stride 2).
    Runs on CUDA when available and falls back to CPU otherwise, instead of
    calling ``.cuda()`` unconditionally, which fails on CPU-only hosts.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    block = BasicBlock(1, 3, stride=1, downsample=None).to(device)
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32, device=device)
    output = block(img)
    assert output.shape == (1, 3, 12, 12)

    # ----------------- use the downsample module--------------- #
    # With stride=2 the expected output is 6x6 from a 12x12 input; the 0.5x
    # nearest resize likewise maps the 12x12 input to 6x6.
    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).to(device)
    block = BasicBlock(1, 3, stride=2, downsample=downsample).to(device)
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32, device=device)
    output = block(img)
    assert output.shape == (1, 3, 6, 6)
35
36
37
def test_bottleneck():
    """Test the Bottleneck in arcface_arch.

    Checks the output shape with no downsample module (stride 1) and with a
    0.5x nearest-neighbour resize as the downsample module (stride 2).
    Per the expected shapes, ``planes=1`` yields 4 output channels.
    Runs on CUDA when available and falls back to CPU otherwise, instead of
    calling ``.cuda()`` unconditionally, which fails on CPU-only hosts.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    block = Bottleneck(1, 1, stride=1, downsample=None).to(device)
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32, device=device)
    output = block(img)
    assert output.shape == (1, 4, 12, 12)

    # ----------------- use the downsample module--------------- #
    # With stride=2 the expected output is 6x6 from a 12x12 input; the 0.5x
    # nearest resize likewise maps the 12x12 input to 6x6.
    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).to(device)
    block = Bottleneck(1, 1, stride=2, downsample=downsample).to(device)
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32, device=device)
    output = block(img)
    assert output.shape == (1, 4, 6, 6)
50
51