GitHub Repository: dbolya/tide
Path: blob/master/tidecv/ap.py
from collections import defaultdict

import numpy as np
from pycocotools import mask as mask_utils

from .data import Data
from . import functions as f


class APDataObject:
    """
    Stores all the information necessary to calculate the AP for one IoU threshold and one class.
    Note: I type annotated this because why not.
    """

    def __init__(self):
        self.data_points = {}
        self.false_negatives = set()
        self.num_gt_positives = 0
        self.curve = None

    def apply_qualifier(self, kept_preds:set, kept_gts:set) -> object:
        """ Makes a new data object that keeps only the given pred and gt ids. """
        obj = APDataObject()
        num_gt_removed = 0

        for pred_id in self.data_points:
            score, is_true, info = self.data_points[pred_id]

            # If the data point is a true positive, it has a corresponding ground truth,
            # so only keep that positive if its matched ground truth was also kept.
            if is_true and info['matched_with'] not in kept_gts:
                num_gt_removed += 1
                continue

            if pred_id in kept_preds:
                obj.data_points[pred_id] = self.data_points[pred_id]

        # Propagate the kept ground truths
        obj.false_negatives = self.false_negatives.intersection(kept_gts)
        num_gt_removed += (len(self.false_negatives) - len(obj.false_negatives))

        obj.num_gt_positives = self.num_gt_positives - num_gt_removed
        return obj

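    # A minimal sketch (not in the original file) of how apply_qualifier is used: given
    # the sets of prediction and gt ids that satisfy some qualifier (e.g., "large
    # objects only"), it returns a filtered copy whose AP can be computed on its own.
    # The ids below are made up for illustration:
    #
    #   obj = APDataObject()
    #   obj.add_gt_positives(2)
    #   obj.push(0, 0.9, True, {'matched_with': 10})   # TP matched to gt 10
    #   obj.push(1, 0.8, False)                        # FP
    #   obj.push_false_negative(11)                    # gt 11 was never matched
    #
    #   large_only = obj.apply_qualifier(kept_preds={0}, kept_gts={10})
    #   # large_only now has 1 gt positive and 1 TP; the FP and the FN are gone.
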
    def push(self, id:int, score:float, is_true:bool, info:dict={}):
        self.data_points[id] = (score, is_true, info)

    def push_false_negative(self, id:int):
        self.false_negatives.add(id)

    def add_gt_positives(self, num_positives:int):
        """ Call this once per image. """
        self.num_gt_positives += num_positives

    def is_empty(self) -> bool:
        return len(self.data_points) == 0 and self.num_gt_positives == 0

    def get_pr_curve(self) -> tuple:
        if self.curve is None:
            self.get_ap()
        return self.curve

    def get_ap(self) -> float:
        """ Warning: the AP value itself is not cached (though the PR curve is). """

        if self.num_gt_positives == 0:
            return 0

        # Sort descending by score
        data_points = list(self.data_points.values())
        data_points.sort(key=lambda x: -x[0])

        precisions = []
        recalls = []
        num_true = 0
        num_false = 0

        # Compute the precision-recall curve. The x axis is recall and the y axis is precision.
        for datum in data_points:
            # datum[1] is whether the detection is a true or false positive
            if datum[1]: num_true += 1
            else: num_false += 1

            precision = num_true / (num_true + num_false)
            recall = num_true / self.num_gt_positives

            precisions.append(precision)
            recalls.append(recall)

        # Smooth the curve by computing [max(precisions[i:]) for i in range(len(precisions))]
        # in place. This removes any temporary dips from the curve, matching what COCOEval does.
        for i in range(len(precisions)-1, 0, -1):
            if precisions[i] > precisions[i-1]:
                precisions[i-1] = precisions[i]
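        # e.g. precisions [1.0, 0.5, 0.75, 0.6] become [1.0, 0.75, 0.75, 0.6]: each entry
        # is replaced by the max precision at that recall or any higher recall (the usual
        # "precision envelope"). The numbers here are illustrative, not from the source.
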
        # Compute the integral of precision(recall) d(recall) from recall=0 to 1,
        # using a fixed-resolution Riemann sum with 101 bars.
        resolution = 100 # Standard COCO resolution
        y_range = [0] * (resolution + 1) # idx 0 is recall == 0.0 and idx 100 is recall == 1.00
        x_range = np.array([x / resolution for x in range(resolution + 1)])
        recalls = np.array(recalls)

        # I realize this is weird, but all it does is find the nearest precision(x) for a given x in x_range.
        # Basically, if the closest recall we have to 0.01 is 0.009, this sets precision(0.01) = precision(0.009).
        # The integral is approximated this way because that's how COCOEval does it.
        indices = np.searchsorted(recalls, x_range, side='left')
        for bar_idx, precision_idx in enumerate(indices):
            if precision_idx < len(precisions):
                y_range[bar_idx] = precisions[precision_idx]

        self.curve = (x_range, y_range)

        # Finally compute the Riemann sum to get our integral:
        # avg([precision(x) for x in 0:0.01:1])
        return sum(y_range) / len(y_range) * 100


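# A minimal sketch (not part of the original file) that walks through the AP math above
# on a toy case; _example_ap is a made-up name.
def _example_ap() -> float:
    obj = APDataObject()
    obj.add_gt_positives(2)                        # 2 ground truths in total
    obj.push(0, 0.9, True, {'matched_with': 10})   # TP -> recall 0.5, precision 1.0
    obj.push(1, 0.8, False)                        # FP -> precision dips to 0.5
    obj.push(2, 0.7, True, {'matched_with': 11})   # TP -> recall 1.0, precision 2/3

    # After the envelope smoothing, precision(r) = 1.0 for r <= 0.5 and 2/3 above it,
    # so the 101-bar sum gives (51 * 1.0 + 50 * 2/3) / 101 * 100 ~= 83.5.
    return obj.get_ap()

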
class ClassedAPDataObject:
    """ Stores an APDataObject for each class in the dataset. """

    def __init__(self):
        self.objs = defaultdict(APDataObject)

    def apply_qualifier(self, pred_dict:dict, gt_dict:dict) -> object:
        ret = ClassedAPDataObject()

        for _class, obj in self.objs.items():
            pred_list = pred_dict.get(_class, set())
            gt_list = gt_dict.get(_class, set())

            ret.objs[_class] = obj.apply_qualifier(pred_list, gt_list)

        return ret

    def push(self, class_:int, id:int, score:float, is_true:bool, info:dict={}):
        self.objs[class_].push(id, score, is_true, info)

    def push_false_negative(self, class_:int, id:int):
        self.objs[class_].push_false_negative(id)

    def add_gt_positives(self, class_:int, num_positives:int):
        self.objs[class_].add_gt_positives(num_positives)

    def get_mAP(self) -> float:
        aps = [x.get_ap() for x in self.objs.values() if not x.is_empty()]
        return sum(aps) / len(aps) if aps else 0

    def get_gt_positives(self) -> dict:
        return {k: v.num_gt_positives for k, v in self.objs.items()}

    def get_pr_curve(self, cat_id:int=None) -> tuple:
        if cat_id is None:
            # Average the per-class curves when using all categories
            curves = [x.get_pr_curve() for x in self.objs.values()]
            x_range = curves[0][0]
            y_range = [0] * len(curves[0][1])

            for _, y in curves:
                for i in range(len(y)):
                    y_range[i] += y[i]

            for i in range(len(y_range)):
                y_range[i] /= len(curves)
        else:
            x_range, y_range = self.objs[cat_id].get_pr_curve()

        return x_range, y_range


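# A minimal sketch (not part of the original file) of ClassedAPDataObject aggregating AP
# across classes; _example_map and the class ids are made up for illustration.
def _example_map() -> float:
    objs = ClassedAPDataObject()
    for class_id in (1, 2):
        objs.add_gt_positives(class_id, 1)
        objs.push(class_id, id=class_id, score=0.9, is_true=True,
                  info={'matched_with': 0})
    # Each class has a single perfect detection, so every per-class AP (and thus
    # the mAP) is 100.
    return objs.get_mAP()

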
# Note: Unused.
class APEval:
    """
    A class that computes the AP of some dataset.
    Note that TIDE doesn't use this internally.
    This is here so you can get a look at how the AP calculation process works.
    """

    def __init__(self):
        self.iou_thresholds = [x / 100 for x in range(50, 100, 5)]
        self.ap_data = defaultdict(lambda: defaultdict(APDataObject))
        # Running counter used as the unique prediction id that APDataObject.push expects.
        self._next_id = 0

    def _eval_image(self, preds:list, gt:list, type_str:str='box'):
        data_type = 'segmentation' if type_str == 'mask' else 'bbox'
        preds_data = [x[data_type] for x in preds]

        # Split the gt into normal and crowd annotations
        gt_new = []
        gt_crowd = []

        for x in gt:
            if x['iscrowd']:
                gt_crowd.append(x)
            else:
                gt_new.append(x)

        gt = gt_new

        # Some setup
        num_pred = len(preds)
        num_gt = len(gt)
        num_crowd = len(gt_crowd)

        iou_cache = mask_utils.iou(
            preds_data,
            [x[data_type] for x in gt],
            [False] * num_gt)

        if num_crowd > 0:
            crowd_iou_cache = mask_utils.iou(
                preds_data,
                [x[data_type] for x in gt_crowd],
                [True] * num_crowd)

        # Make sure we're evaluating sorted by descending score
        indices = list(range(num_pred))
        indices.sort(key=lambda i: -preds[i]['score'])

        classes = [x['category_id'] for x in preds]
        gt_classes = [x['category_id'] for x in gt]
        crowd_classes = [x['category_id'] for x in gt_crowd]

        for _class in set(classes + gt_classes):
            num_gt_for_class = sum([1 for x in gt_classes if x == _class])

            for iouIdx in range(len(self.iou_thresholds)):
                iou_threshold = self.iou_thresholds[iouIdx]

                gt_used = [False] * len(gt_classes)

                ap_obj = self.ap_data[iouIdx][_class]
                ap_obj.add_gt_positives(num_gt_for_class)

                for i in indices:
                    if classes[i] != _class:
                        continue

                    # Greedily match this detection to the unused gt of the same class
                    # with the highest IoU above the threshold, as COCOEval does.
                    max_iou_found = iou_threshold
                    max_match_idx = -1
                    for j in range(num_gt):
                        if gt_used[j] or gt_classes[j] != _class:
                            continue

                        iou = iou_cache[i][j]

                        if iou > max_iou_found:
                            max_iou_found = iou
                            max_match_idx = j

                    if max_match_idx >= 0:
                        gt_used[max_match_idx] = True
                        ap_obj.push(self._next_id, preds[i]['score'], True)
                        self._next_id += 1
                    else:
                        # If the detection matches a crowd, we can just ignore it
                        matched_crowd = False

                        if num_crowd > 0:
                            for j in range(len(crowd_classes)):
                                if crowd_classes[j] != _class:
                                    continue

                                iou = crowd_iou_cache[i][j]

                                if iou > iou_threshold:
                                    matched_crowd = True
                                    break

                        # All this crowd code is here so that our eval gives the same result
                        # as COCOEval. There aren't even that many crowd annotations to begin
                        # with, but accuracy is of the utmost importance.
                        if not matched_crowd:
                            ap_obj.push(self._next_id, preds[i]['score'], False)
                            self._next_id += 1

    def evaluate(self, preds:Data, gt:Data, type_str:str='box'):
        for id in gt.ids:
            x = preds.get(id)
            y = gt.get(id)

            self._eval_image(x, y, type_str)

    def compute_mAP(self):
        num_threshs = len(self.ap_data)
        thresh_APs = []

        for thresh, classes in self.ap_data.items():
            num_classes = len([x for x in classes.values() if not x.is_empty()])
            ap = 0

            if num_classes > 0:
                class_APs = [x.get_ap() for x in classes.values() if not x.is_empty()]
                ap = sum(class_APs) / num_classes

            thresh_APs.append(ap)

        # Average over the 10 IoU thresholds (0.50:0.05:0.95) to get COCO-style mAP.
        # Note: get_ap already returns a percentage (0-100), so no extra scaling here.
        return round(sum(thresh_APs) / num_threshs, 2)
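

# A minimal end-to-end sketch (not part of the original file) of the unused APEval class.
# It uses hypothetical COCO-style dicts and calls _eval_image directly instead of going
# through tidecv's Data objects; _example_apeval is a made-up name.
def _example_apeval() -> float:
    ev = APEval()
    preds = [{'bbox': [0, 0, 10, 10], 'score': 0.9, 'category_id': 1}]
    gt    = [{'bbox': [0, 0, 10, 10], 'category_id': 1, 'iscrowd': 0}]

    ev._eval_image(preds, gt, type_str='box')
    # One perfect detection at every IoU threshold -> mAP of 100.0
    return ev.compute_mAP()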