Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/FBAMatting/generate_trimaps.py
3118 views
1
import os
2
import argparse
3
import torch
4
import numpy as np
5
from torchvision import transforms
6
import cv2
7
8
# File-name suffixes that are treated as input images.
IMG_EXT = ('.png', '.jpg', '.jpeg', '.JPG', '.JPEG')

# Class names listed by channel index of the segmentation output
# (Pascal VOC-style label set); position in this tuple == class index.
_VOC_CLASS_NAMES = (
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car",
    "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "potted plant", "sheep", "sofa", "train", "tv/monitor",
)

# Maps a class name to the output channel that holds its score.
CLASS_MAP = {name: index for index, name in enumerate(_VOC_CLASS_NAMES)}
13
14
15
def trimap(probs, size, conf_threshold, mask_threshold=0.05):
    """
    Create a trimap from a foreground-probability map using simple dilation.

    Inputs:
        probs: per-pixel probability of belonging to the foreground
            (e.g. an (H, W) float array); anything accepted by ``np.asarray``.
        size: dilation radius in pixels. The square structuring element is
            ``(2 * size + 1) x (2 * size + 1)``; ``size=0`` means no dilation.
        conf_threshold: pixels with ``probs > conf_threshold`` are marked as
            definite foreground.
        mask_threshold: pixels with ``probs > mask_threshold`` seed the
            dilation that forms the uncertain band (default 0.05, the value
            that used to be hard-coded).

    Output:
        uint8 array shaped like ``probs``: 0 outside the dilated region,
        127 inside the dilated region (uncertain), 255 where
        ``probs > conf_threshold`` (definite foreground).

    Raises:
        ValueError: if ``size`` is negative.
    """
    if size < 0:
        raise ValueError(f"size must be non-negative, got {size}")
    probs = np.asarray(probs)

    pixels = 2 * size + 1
    kernel = np.ones((pixels, pixels), np.uint8)

    # Binary seed mask (0/255) of every pixel with non-negligible probability.
    mask = (probs > mask_threshold).astype(np.uint8) * 255
    dilation = cv2.dilate(mask, kernel, iterations=1)

    remake = np.zeros_like(mask)
    remake[dilation == 255] = 127  # Every pixel within the dilated region is "probably foreground".
    remake[probs > conf_threshold] = 255  # Pixels with large enough probability are definitely foreground.

    return remake
34
35
36
def parse_args(argv=None):
    """
    Parse the command-line options of the trimap generator.

    Args:
        argv: optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is the previous behaviour.

    Returns:
        argparse.Namespace with ``input_dir``, ``target_class``, ``show`` and
        ``conf_threshold`` attributes.
    """
    parser = argparse.ArgumentParser(description="Deeplab Segmentation")
    parser.add_argument(
        "-i",
        "--input_dir",
        type=str,
        required=True,
        help="Directory containing the input images; trimaps are written to "
             "its 'trimaps' subdirectory. (required)",
    )
    parser.add_argument(
        "--target_class",
        type=str,
        default='person',
        choices=list(CLASS_MAP),
        help="Type of the foreground object.",
    )
    parser.add_argument(
        "--show",
        action='store_true',
        help="Use to show results.",
    )
    parser.add_argument(
        "--conf_threshold",
        type=float,
        default=0.95,  # a real float, not a string that argparse has to convert
        help="Confidence threshold for the foreground object. "
             "You can play with it to get better looking trimaps.",
    )

    return parser.parse_args(argv)
67
68
69
def main(input_dir, target_class, show, conf_threshold):
    """
    Generate and save a trimap for every image found in ``input_dir``.

    Each image is segmented with a pretrained DeepLabV3-ResNet101 model
    loaded through ``torch.hub``. The softmax probability map of
    ``target_class`` is converted to a trimap with ``trimap`` (dilation
    radius 7). The result is written to ``<input_dir>/trimaps/<stem>.png``.

    Args:
        input_dir: directory with the input images (non-recursive).
        target_class: key of ``CLASS_MAP`` selecting the foreground class.
        show: if True, display the probability map, the image and the trimap
            for each image and wait for a key press.
        conf_threshold: probability above which a pixel is definite foreground.
    """
    model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
    model.eval()

    # Loop-invariant: the preprocessing pipeline and the class channel index.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    class_index = CLASS_MAP[target_class]

    trimaps_path = os.path.join(input_dir, "trimaps")
    os.makedirs(trimaps_path, exist_ok=True)

    for filename in sorted(os.listdir(input_dir)):
        # Case-insensitive match so e.g. ".PNG" or ".Jpg" files are not skipped.
        if not filename.lower().endswith(IMG_EXT):
            continue
        image_path = os.path.join(input_dir, filename)
        original_image = cv2.imread(image_path)
        if original_image is None:
            # cv2.imread signals failure by returning None rather than raising.
            print(f"Skipping unreadable image: {image_path}")
            continue

        # cvtColor returns a new array, so original_image stays BGR for display.
        input_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
        input_batch = preprocess(input_image).unsqueeze(0)  # create a mini-batch as expected by the model

        with torch.no_grad():
            output = model(input_batch)['out'][0]
            output = torch.softmax(output, 0)

        output_cat = output[class_index, ...].numpy()

        trimap_image = trimap(output_cat, 7, conf_threshold)
        # splitext keeps inner dots ("a.1.jpg" -> "a.1.png"), so distinct
        # inputs no longer collapse onto the same trimap file.
        trimap_filename = os.path.splitext(filename)[0] + '.png'
        cv2.imwrite(os.path.join(trimaps_path, trimap_filename), trimap_image)

        if show:
            cv2.imshow('mask', output_cat)
            cv2.imshow('image', original_image)
            cv2.imshow('trimap', trimap_image)
            cv2.waitKey(0)

    if show:
        cv2.destroyAllWindows()
108
109
110
if __name__ == "__main__":
    # Read the CLI options and hand each one to main() by name.
    args = parse_args()
    main(
        input_dir=args.input_dir,
        target_class=args.target_class,
        show=args.show,
        conf_threshold=args.conf_threshold,
    )
113
114