Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
hackassin
GitHub Repository: hackassin/learnopencv
Path: blob/master/FBAMatting/dataloader.py
3118 views
1
import os
2
3
import cv2
4
import numpy as np
5
from torch.utils.data import Dataset
6
7
8
class PredDataset(Dataset):
    """Reads image and trimap pairs from folder.

    Every regular file directly inside ``img_dir`` is treated as an input
    image. Its trimap is looked up in ``trimap_dir`` under the same base name
    with a ``.png`` extension (see ``_trimap_name``). Each item is a dict with
    keys ``"image"`` (output of ``read_image``), ``"trimap"`` (output of
    ``read_trimap``) and ``"name"`` (the image's file name).
    """

    def __init__(self, img_dir, trimap_dir):
        """Index the image files in ``img_dir``.

        Args:
            img_dir: Directory containing the input images.
            trimap_dir: Directory containing the matching ``.png`` trimaps.
        """
        self.img_dir, self.trimap_dir = img_dir, trimap_dir
        # Sort so that dataset indices refer to the same files on every run;
        # os.listdir returns entries in an arbitrary, filesystem-dependent order.
        # Sub-directories are skipped; only regular files are kept.
        self.img_names = sorted(
            x
            for x in os.listdir(self.img_dir)
            if os.path.isfile(os.path.join(self.img_dir, x))
        )

    def __len__(self):
        """Return the number of images found in ``img_dir``."""
        return len(self.img_names)

    @staticmethod
    def _trimap_name(img_name):
        """Return the trimap file name for *img_name*: same stem, ``.png`` extension."""
        # splitext copes with extensions of any length ("a.jpeg" -> "a.png"),
        # unlike slicing off a fixed number of trailing characters.
        return os.path.splitext(img_name)[0] + ".png"

    def __getitem__(self, idx):
        """Load the image/trimap pair at position *idx*.

        Returns:
            dict with keys ``"image"``, ``"trimap"`` and ``"name"``.
        """
        img_name = self.img_names[idx]
        trimap_name = self._trimap_name(img_name)

        image = read_image(os.path.join(self.img_dir, img_name))
        trimap = read_trimap(os.path.join(self.trimap_dir, trimap_name))
        return {"image": image, "trimap": trimap, "name": img_name}
33
34
35
def read_image(name):
    """Load a colour image as an RGB float array scaled to ``[0, 1]``.

    Args:
        name: Path to the image file.

    Returns:
        ``np.ndarray`` of shape ``(H, W, 3)``, dtype float64, channels in RGB
        order, C-contiguous.

    Raises:
        OSError: If OpenCV cannot read or decode the file.
    """
    bgr = cv2.imread(name)
    # cv2.imread reports failure by returning None instead of raising; without
    # this check the caller only sees an opaque TypeError from the division.
    if bgr is None:
        raise OSError(f"Could not read image file: {name!r}")
    # OpenCV decodes to BGR, so reverse the channel axis to get RGB. The
    # reversed slice is a negative-stride view, which torch.from_numpy rejects,
    # so materialise it as a contiguous array.
    return np.ascontiguousarray((bgr / 255.0)[:, :, ::-1])
37
38
39
def read_trimap(name):
    """Load a grayscale trimap and split it into two binary mask channels.

    Channel 1 is 1 where the trimap pixel is 255 (white) and channel 0 is 1
    where it is 0 (black). Any other grey level is left as 0 in both channels,
    i.e. the "unknown" region. (Presumably white = definite foreground and
    black = definite background, per the usual trimap convention; confirm
    against the model's expectations.)

    Args:
        name: Path to the trimap image file.

    Returns:
        ``np.ndarray`` of shape ``(H, W, 2)`` and dtype float64.

    Raises:
        OSError: If OpenCV cannot read or decode the file.
    """
    gray = cv2.imread(name, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None on failure rather than raising.
    if gray is None:
        raise OSError(f"Could not read trimap file: {name!r}")
    h, w = gray.shape
    trimap = np.zeros((h, w, 2), dtype=np.float64)
    # Compare on the 8-bit values directly; equivalent to testing gray / 255.0
    # against 1.0 and 0.0, but without relying on exact float equality.
    trimap[gray == 255, 1] = 1
    trimap[gray == 0, 0] = 1
    return trimap
46
47