Path: blob/master/modules/python/test/test_grabcut.py
16337 views
#!/usr/bin/env python
'''
===============================================================================
Interactive Image Segmentation using GrabCut algorithm.
===============================================================================
'''

# Python 2/3 compatibility
from __future__ import print_function

import numpy as np
import cv2 as cv
import sys

from tests_common import NewOpenCVTests


class grabcut_test(NewOpenCVTests):
    '''Regression test comparing cv.grabCut results against stored reference masks.'''

    def verify(self, mask, exp):
        '''Tell whether `mask` is close enough to the expected mask `exp`.

        The number of differing elements is divided by the number of non-zero
        elements of `exp`; the masks match when that ratio is below 0.02.

        Args:
            mask: array holding the computed mask.
            exp: array holding the expected mask.

        Returns:
            True when the masks match within the tolerance, False otherwise
            (including when the two arrays have different shapes).
        '''
        maxDiffRatio = 0.02

        # An element-wise comparison is only meaningful for equally shaped
        # masks; treat a shape mismatch as a plain verification failure.
        if np.shape(mask) != np.shape(exp):
            return False

        expArea = np.count_nonzero(exp)
        nonIntersectArea = np.count_nonzero(mask != exp)

        if expArea == 0:
            # The ratio is undefined for an empty expected mask (it would be a
            # division by zero), so only an identical mask is accepted.
            return nonIntersectArea == 0

        curRatio = float(nonIntersectArea) / expArea
        return curRatio < maxDiffRatio

    def scaleMask(self, mask):
        '''Convert a GrabCut label mask into a binary uint8 image.

        Elements equal to cv.GC_FGD or cv.GC_PR_FGD become 255; every other
        element becomes 0.
        '''
        isForeground = (mask == cv.GC_FGD) | (mask == cv.GC_PR_FGD)
        return np.where(isForeground, 255, 0).astype('uint8')

    def test_grabcut(self):
        '''Run grabCut with rectangle and mask initialisation and check both results.

        Reference images that cannot be loaded are regenerated from the current
        output and written below `self.extraTestDataPath`.
        '''
        img = self.get_sample('cv/shared/airplane.png')
        # Second argument 0 is forwarded to get_sample as the load flag.
        mask_prob = self.get_sample("cv/grabcut/mask_probpy.png", 0)
        exp_mask1 = self.get_sample("cv/grabcut/exp_mask1py.png", 0)
        exp_mask2 = self.get_sample("cv/grabcut/exp_mask2py.png", 0)

        if img is None:
            self.fail('Missing test data')

        # Pass 1: initialise from a rectangle, then run 2 evaluation iterations.
        rect = (24, 126, 459, 168)
        mask = np.zeros(img.shape[:2], dtype=np.uint8)
        bgdModel = np.zeros((1, 65), np.float64)
        fgdModel = np.zeros((1, 65), np.float64)
        cv.grabCut(img, mask, rect, bgdModel, fgdModel, 0, cv.GC_INIT_WITH_RECT)
        cv.grabCut(img, mask, rect, bgdModel, fgdModel, 2, cv.GC_EVAL)

        if mask_prob is None:
            # No stored label mask: keep a copy of the raw pass-1 labels.
            mask_prob = mask.copy()
            cv.imwrite(self.extraTestDataPath + '/cv/grabcut/mask_probpy.png', mask_prob)
        if exp_mask1 is None:
            exp_mask1 = self.scaleMask(mask)
            cv.imwrite(self.extraTestDataPath + '/cv/grabcut/exp_mask1py.png', exp_mask1)

        self.assertTrue(self.verify(self.scaleMask(mask), exp_mask1))

        # Pass 2: initialise from the label mask with fresh model buffers, then
        # run 1 evaluation iteration. `mask` aliases `mask_prob` here, so
        # grabCut is handed the very same array object.
        mask = mask_prob
        bgdModel = np.zeros((1, 65), np.float64)
        fgdModel = np.zeros((1, 65), np.float64)
        cv.grabCut(img, mask, rect, bgdModel, fgdModel, 0, cv.GC_INIT_WITH_MASK)
        cv.grabCut(img, mask, rect, bgdModel, fgdModel, 1, cv.GC_EVAL)

        if exp_mask2 is None:
            exp_mask2 = self.scaleMask(mask)
            cv.imwrite(self.extraTestDataPath + '/cv/grabcut/exp_mask2py.png', exp_mask2)

        self.assertTrue(self.verify(self.scaleMask(mask), exp_mask2))


if __name__ == '__main__':
    NewOpenCVTests.bootstrap()