Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/FBAMatting/demo.py
3118 views
1
import argparse
2
import os
3
4
import cv2
5
import numpy as np
6
import torch
7
from dataloader import PredDataset
8
from networks.models import build_model
9
from networks.transforms import (
10
groupnorm_normalise_image,
11
trimap_transform,
12
)
13
from tqdm import tqdm
14
15
16
def np_to_torch(x):
    """Turn an (H, W, C) numpy array into a float tensor shaped (1, C, H, W).

    Channels are moved to the front and a leading batch axis of size 1 is
    added, then the result is cast with ``.float()``.
    """
    channels_first = torch.from_numpy(x).permute(2, 0, 1)
    batched = channels_first[None]
    return batched.float()
18
19
20
def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray:
    """Resize ``x`` by ``scale``, rounding each side up to a multiple of 8.

    Parameters:
    x -- image-like array; only its first two dimensions (height, width) are resized.
    scale -- multiplicative factor applied to height and width before rounding.
    scale_type -- OpenCV interpolation flag forwarded to ``cv2.resize``.
    Returns:
    the resized array, whose height and width are both multiples of 8.
    """
    def round_up_to_8(length):
        return int(np.ceil(scale * length / 8) * 8)

    height, width = x.shape[:2]
    # OpenCV takes the target size as (width, height).
    target_size = (round_up_to_8(width), round_up_to_8(height))
    return cv2.resize(x, target_size, interpolation=scale_type)
27
28
29
def swap_bg(image, alpha):
    """Composite ``image`` over a solid green backdrop using the matte ``alpha``.

    Parameters:
    image -- (h, w, C) image on a 0-255 scale. The backdrop has 255 in
             channel 1, which is green in both RGB and BGR ordering.
    alpha -- (h, w) matte; 1 keeps the image pixel, 0 shows the backdrop.
    Returns:
    uint8 array of the same shape as ``image``, clipped to [0, 255].
    """
    weight = alpha[:, :, None]  # broadcast the matte over the channel axis
    backdrop = np.zeros_like(image).astype(np.float32)
    backdrop[:, :, 1] = 255
    foreground = image.astype(np.float32)
    blended = weight * foreground + (1 - weight) * backdrop
    return np.clip(blended, 0, 255).astype(np.uint8)
38
39
40
def _write_image(path, image):
    """Write ``image`` to ``path`` with OpenCV, raising if the write fails."""
    # cv2.imwrite reports failure only through its boolean return value.
    if not cv2.imwrite(path, image):
        raise OSError(f"Could not write image to {path}")


def predict_fba_folder(model, args):
    """Run FBA matting on every image/trimap pair and save the predictions.

    For each item of ``PredDataset(args.image_dir, args.trimap_dir)`` four PNGs
    are written to ``args.output_dir`` (created if missing), where ``<stem>`` is
    the item's ``name`` without its extension:
    ``<stem>_fg.png``, ``<stem>_bg.png``, ``<stem>_alpha.png`` and
    ``<stem>_swapped_bg.png`` (the foreground composited over green).

    Parameters:
    model -- matting network, passed unchanged to ``pred``.
    args -- parsed CLI namespace; ``output_dir``, ``image_dir`` and
            ``trimap_dir`` are read here, ``device`` inside ``pred``.
    Raises:
    OSError -- if any output image cannot be written.
    """
    save_dir = args.output_dir
    os.makedirs(save_dir, exist_ok=True)

    dataset_test = PredDataset(args.image_dir, args.trimap_dir)

    # Passing the dataset itself (not an iterator) lets tqdm use len() for a total.
    for item_dict in tqdm(dataset_test):
        image_np = item_dict["image"]
        trimap_np = item_dict["trimap"]

        fg, bg, alpha = pred(image_np, trimap_np, model, args)

        # splitext copes with any extension length (".png", ".jpeg", ...);
        # slicing off the last 4 characters only worked for 3-letter extensions.
        stem = os.path.splitext(item_dict["name"])[0]
        base = os.path.join(save_dir, stem)

        # pred returns RGB in [0, 1]; OpenCV writes BGR, hence the channel flip.
        fg_bgr = fg[:, :, ::-1] * 255
        _write_image(base + "_fg.png", fg_bgr)
        _write_image(base + "_bg.png", bg[:, :, ::-1] * 255)
        _write_image(base + "_alpha.png", alpha * 255)
        _write_image(base + "_swapped_bg.png", swap_bg(fg_bgr, alpha))
69
70
71
def pred(
    image_np: np.ndarray, trimap_np: np.ndarray, model, args,
) -> "tuple[np.ndarray, np.ndarray, np.ndarray]":
    """ Predict alpha, foreground and background.
    Parameters:
    image_np -- the image in rgb format between 0 and 1. Dimensions: (h, w, 3)
    trimap_np -- two channel trimap, first background then foreground. Dimensions: (h, w, 2)
    model -- matting network called with (image, trimap, normalised image,
             transformed trimap) tensors.
    args -- namespace providing ``device`` for the input tensors.
    Returns:
    fg: foreground image in rgb format between 0 and 1. Dimensions: (h, w, 3)
    bg: background image in rgb format between 0 and 1. Dimensions: (h, w, 3)
    alpha: alpha matte image between 0 and 1. Dimensions: (h, w)
    """
    h, w = trimap_np.shape[:2]

    # Resize so both spatial dimensions are multiples of 8 (see scale_input).
    image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4)
    trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4)

    with torch.no_grad():
        image_torch = np_to_torch(image_scale_np).to(args.device)
        trimap_torch = np_to_torch(trimap_scale_np).to(args.device)

        trimap_transformed_torch = np_to_torch(
            trimap_transform(trimap_scale_np),
        ).to(args.device)
        image_transformed_torch = groupnorm_normalise_image(
            image_torch.clone(), format="nchw",
        )

        output = model(
            image_torch,
            trimap_torch,
            image_transformed_torch,
            trimap_transformed_torch,
        )

    # Back to (h, w, channels) at the input resolution. The channel layout
    # (presumably 7: alpha, fg RGB, bg RGB) is inferred from the slicing below.
    # ``interpolation`` must be a keyword: the third positional parameter of
    # cv2.resize is ``dst``, so a positional flag is silently ignored.
    output = cv2.resize(
        output[0].cpu().numpy().transpose((1, 2, 0)),
        (w, h),
        interpolation=cv2.INTER_LANCZOS4,
    )
    # Lanczos can overshoot slightly; enforce the documented [0, 1] range.
    output = np.clip(output, 0.0, 1.0)

    alpha = output[:, :, 0]
    fg = output[:, :, 1:4]
    bg = output[:, :, 4:7]

    # Where the trimap is certain, override the prediction.
    alpha[trimap_np[:, :, 0] == 1] = 0
    alpha[trimap_np[:, :, 1] == 1] = 1
    # Fully opaque / fully transparent pixels show the true fg / bg colour.
    fg[alpha == 1] = image_np[alpha == 1]
    bg[alpha == 0] = image_np[alpha == 0]
    return fg, bg, alpha
117
118
119
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run FBA Matting on a folder of images and trimaps.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Model related arguments
    parser.add_argument("--encoder", default="resnet50_GN_WS", help="Encoder model")
    parser.add_argument("--decoder", default="fba_decoder", help="Decoder model")
    parser.add_argument("--weights", default="FBA.pth", help="Model weights file")
    # Data related arguments
    parser.add_argument(
        "--image_dir", default="./examples/images", help="Directory of input images",
    )
    parser.add_argument(
        "--trimap_dir",
        default="./examples/trimaps",
        help="Directory of trimaps corresponding to the input images",
    )
    parser.add_argument(
        "--output_dir",
        default="./examples/predictions",
        help="Directory where predictions are written (created if missing)",
    )
    parser.add_argument("--device", default="cpu", help="Device for inference on")

    args = parser.parse_args()
    model = build_model(args).to(args.device)
    model.eval()
    predict_fba_folder(model, args)
137
138