Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
dbolya
GitHub Repository: dbolya/tide
Path: blob/master/tidecv/errors/error.py
110 views
1
from typing import Union
2
3
import cv2
4
from .. import functions as f
5
6
class Error:
    """A base class for all error types.

    Subclasses are expected to store the offending prediction as ``self.pred``
    and/or the related ground truth as ``self.gt`` (records indexed by string
    keys), and to define ``short_name``. The helpers below use ``hasattr`` to
    decide which of ``pred`` / ``gt`` are present.
    """

    def fix(self) -> Union[tuple, None]:
        """
        Returns a fixed version of the AP data point for this error or
        None if this error should be suppressed.

        Return type is:
            class:int, (score:float, is_positive:bool, info:dict)
        """
        raise NotImplementedError

    def unfix(self) -> Union[tuple, None]:
        """Returns the original version of this data point.

        The result has the same shape as ``fix`` but marks the prediction as a
        false positive: ``(pred['class'], (pred['score'], False, pred['info']))``.
        Returns None when this error has no prediction, or when the
        prediction's ``'used'`` entry is None.
        """
        if not hasattr(self, 'pred'):
            return None

        pred = self.pred
        # If an ignored instance is an error, it's not in the data point list,
        # so there's no "unfixed" entry to restore.
        if pred['used'] is None:
            return None

        return pred['class'], (pred['score'], False, pred['info'])

    def get_id(self) -> int:
        """Return the prediction's ``'_id'``, else the GT's ``'_id'``, else -1."""
        if hasattr(self, 'pred'):
            return self.pred['_id']
        if hasattr(self, 'gt'):
            return self.gt['_id']
        return -1

    def show(self, dataset, out_path: Union[str, None] = None,
             pred_color: tuple = (43, 12, 183), gt_color: tuple = (43, 183, 12),
             font=cv2.FONT_HERSHEY_SIMPLEX):
        """Render this error on its image, then display or save the result.

        The base image is ``dataset.get_img_with_anns`` for the prediction's
        ``'image_id'`` (or the GT's, if there is no prediction). If present, the
        GT box is drawn in ``gt_color`` with its category name. If present, the
        prediction box is drawn in ``pred_color`` with its category name and
        score.

        Args:
            dataset: object providing ``get_img_with_anns(image_id)`` and
                ``cat_name(category_id)``.
            out_path: if given, the rendered image is written there with
                ``cv2.imwrite``. Otherwise it is shown in a window titled
                ``self.short_name``, and the call blocks until a key is pressed.
            pred_color, gt_color: color tuples handed to OpenCV (BGR order).
            font: OpenCV font face used for the text labels.
        """
        anchor = self.pred if hasattr(self, 'pred') else self.gt
        img = dataset.get_img_with_anns(anchor['image_id'])

        # Labels are drawn at fixed pixel positions: GT at y=200, pred at y=100.
        if hasattr(self, 'gt'):
            img = cv2.rectangle(img, *f.points(self.gt['bbox']), gt_color, 2)
            img = cv2.putText(img, dataset.cat_name(self.gt['category_id']),
                              (100, 200), font, 1, gt_color, 2, cv2.LINE_AA, False)

        if hasattr(self, 'pred'):
            pred = self.pred
            img = cv2.rectangle(img, *f.points(pred['bbox']), pred_color, 2)
            label = '%s (%.2f)' % (dataset.cat_name(pred['category_id']), pred['score'])
            img = cv2.putText(img, label,
                              (100, 100), font, 1, pred_color, 2, cv2.LINE_AA, False)

        if out_path is None:
            title = self.short_name
            cv2.imshow(title, img)
            cv2.moveWindow(title, 100, 100)
            cv2.waitKey()  # no delay argument: wait indefinitely for a key press
            cv2.destroyAllWindows()
        else:
            cv2.imwrite(out_path, img)

    def get_info(self, dataset):
        """Collect the data needed to inspect this error.

        Returns a dict with these keys:
        - ``'type'``: ``self.short_name``.
        - ``'gt'`` and/or ``'pred'``: whichever of the two this error has.
        - ``'all_gt'``: ``dataset.get(img_id)``.
        - ``'img'``: ``dataset.get_img(img_id)``.

        ``img_id`` is the prediction's ``'image_id'`` if there is a
        prediction, otherwise the GT's.
        """
        info = {'type': self.short_name}

        if hasattr(self, 'gt'):
            info['gt'] = self.gt
        if hasattr(self, 'pred'):
            info['pred'] = self.pred

        img_id = (self.pred if hasattr(self, 'pred') else self.gt)['image_id']
        info['all_gt'] = dataset.get(img_id)
        info['img'] = dataset.get_img(img_id)

        return info
79
80
81
82
83
84
85
86
87
class BestGTMatch:
    """
    Some errors are fixed by changing false positives to true positives.
    The issue with fixing these errors naively is that you might have
    multiple errors attempting to fix the same GT. In that case, we need
    to select which error actually gets fixed, and which others just get
    suppressed (since we can only fix one error per GT).

    To address this, this class finds the prediction with the highest
    score and then uses that as the error to fix, while suppressing all
    other errors caused by the same GT.

    The bookkeeping lives on the shared GT dict itself:
    - ``'best_score'`` and ``'best_id'`` track the highest-scoring prediction
      registered so far. On ties, the earlier registration is kept.
    - ``'usable'`` is set to True when the GT's ``'used'`` entry is falsy.
    """

    def __init__(self, pred, gt):
        self.pred = pred
        self.gt = gt

        # A GT whose 'used' entry is truthy cannot be fixed again, so every
        # error pointing at it is suppressed.
        if self.gt['used']:
            self.suppress = True
        else:
            self.suppress = False
            self.gt['usable'] = True

        score = self.pred['score']

        # Seed with -inf (not -1) so that any score, including ones <= -1,
        # can become the best. Otherwise 'best_id' could stay unset and
        # fix() would raise KeyError.
        if 'best_score' not in self.gt:
            self.gt['best_score'] = float('-inf')

        # Strict comparison: on ties the first registered prediction wins.
        if self.gt['best_score'] < score:
            self.gt['best_score'] = score
            self.gt['best_id'] = self.pred['_id']

    def fix(self):
        """Return ``(score, True, info)`` if this prediction is the GT's winner.

        Returns None when the GT's ``'used'`` entry was truthy at construction
        time, or when a different prediction is recorded as the GT's
        ``'best_id'``.
        """
        if self.suppress or self.gt['best_id'] != self.pred['_id']:
            return None
        return (self.pred['score'], True, self.pred['info'])
124
125