# GitHub Repository: dbolya/tide
# Path: blob/master/tidecv/quantify.py

from .data import Data
from .ap import ClassedAPDataObject
from .errors.main_errors import *
from .errors.qualifiers import Qualifier, AREA
from . import functions as f
from . import plotting as P

from pycocotools import mask as mask_utils
from collections import defaultdict, OrderedDict
import numpy as np
from typing import Union
import os, math


class TIDEExample:
    """ Computes all the data needed to evaluate a set of predictions and gt for a single image. """

    def __init__(self, preds:list, gt:list, pos_thresh:float, mode:str, max_dets:int, run_errors:bool=True):
        self.preds = preds
        self.gt = [x for x in gt if not x['ignore']]
        self.ignore_regions = [x for x in gt if x['ignore']]

        self.mode = mode
        self.pos_thresh = pos_thresh
        self.max_dets = max_dets
        self.run_errors = run_errors

        self._run()

    def _run(self):
        preds = self.preds
        gt = self.gt
        ignore = self.ignore_regions
        det_type = 'bbox' if self.mode == TIDE.BOX else 'mask'
        max_dets = self.max_dets

        if len(preds) == 0:
            raise RuntimeError('Example has no predictions!')

        # Sort descending by score
        preds.sort(key=lambda pred: -pred['score'])
        preds = preds[:max_dets]
        self.preds = preds  # Update internally so TIDERun can update itself if :max_dets takes effect
        detections = [x[det_type] for x in preds]

        # IoU is [len(detections), len(gt)]
        self.gt_iou = mask_utils.iou(
            detections,
            [x[det_type] for x in gt],
            [False] * len(gt))
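
        # (pycocotools' mask_utils.iou takes detections, ground truths, and a per-gt
        # iscrowd flag; with iscrowd=False it returns the plain IoU matrix described
        # above. Boxes are expected in [x, y, w, h] form and masks as COCO RLE.)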

        # Store whether a prediction / gt got used in their data list
        # Note: this is set to None if ignored, so keep that in mind
        for idx, pred in enumerate(preds):
            pred['used'] = False
            pred['_idx'] = idx
            pred['iou'] = 0
        for idx, truth in enumerate(gt):
            truth['used'] = False
            truth['usable'] = False
            truth['_idx'] = idx

        pred_cls = np.array([x['class'] for x in preds])
        gt_cls = np.array([x['class'] for x in gt])

        if len(gt) > 0:
            # A[i,j] is true iff prediction i is of the same class as gt j
            self.gt_cls_matching = (pred_cls[:, None] == gt_cls[None, :])
            self.gt_cls_iou = self.gt_iou * self.gt_cls_matching

            # This will be changed in the matching calculation, so make a copy
            iou_buffer = self.gt_cls_iou.copy()

            for pred_idx, pred_elem in enumerate(preds):
                # Find the max-IoU ground truth for this prediction
                gt_idx = np.argmax(iou_buffer[pred_idx, :])
                iou = iou_buffer[pred_idx, gt_idx]

                pred_elem['iou'] = np.max(self.gt_cls_iou[pred_idx, :])

                if iou >= self.pos_thresh:
                    gt_elem = gt[gt_idx]

                    pred_elem['used'] = True
                    gt_elem['used'] = True
                    pred_elem['matched_with'] = gt_elem['_id']
                    gt_elem['matched_with'] = pred_elem['_id']

                    # Make sure this gt can't be used again
                    iou_buffer[:, gt_idx] = 0
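
        # (A worked sketch of the greedy matching above: with two same-class predictions
        # scored 0.9 and 0.8 that both overlap a single gt at IoU 0.7, and pos_thresh = 0.5,
        # the 0.9 prediction claims the gt first; zeroing that gt's column leaves the 0.8
        # prediction unmatched, so TIDERun._eval_image later reports it as a DuplicateError.)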

        # Ignore region annotations allow us to ignore predictions that fall within them
        if len(ignore) > 0:
            # Because ignore regions have extra parameters, it's more efficient to use a for loop here
            for ignore_region in ignore:
                if ignore_region['mask'] is None and ignore_region['bbox'] is None:
                    # The region should span the whole image
                    ignore_iou = [1] * len(preds)
                else:
                    if ignore_region[det_type] is None:
                        # There is no det_type annotation for this specific region, so skip it
                        continue
                    # Otherwise, compute the crowd IoU between the detections and this region
                    ignore_iou = mask_utils.iou(detections, [ignore_region[det_type]], [True])

                for pred_idx, pred_elem in enumerate(preds):
                    if not pred_elem['used'] and (ignore_iou[pred_idx] > self.pos_thresh) \
                            and (ignore_region['class'] == pred_elem['class'] or ignore_region['class'] == -1):
                        # Set the prediction to be ignored
                        pred_elem['used'] = None

        if len(gt) == 0:
            return

        # Some matrices used just for error calculation
        if self.run_errors:
            self.gt_used = np.array([x['used'] is True for x in gt])[None, :]
            self.gt_unused = ~self.gt_used

            self.gt_unused_iou = self.gt_unused * self.gt_iou
            self.gt_unused_cls = self.gt_unused_iou * self.gt_cls_matching
            self.gt_unused_noncls = self.gt_unused_iou * ~self.gt_cls_matching

            self.gt_noncls_iou = self.gt_iou * ~self.gt_cls_matching

            self.gt_used_iou = self.gt_used * self.gt_iou
            self.gt_used_cls = self.gt_used_iou * self.gt_cls_matching


class TIDERun:
    """ Holds the data for a single run of TIDE. """

    # Temporary variables stored in the ground truth that we need to clear after a run
    _temp_vars = ['best_score', 'best_id', 'used', 'matched_with', '_idx', 'usable']

    def __init__(self, gt:Data, preds:Data, pos_thresh:float, bg_thresh:float, mode:str, max_dets:int, run_errors:bool=True):
        self.gt = gt
        self.preds = preds

        self.errors = []
        self.error_dict = {_type: [] for _type in TIDE._error_types}
        self.ap_data = ClassedAPDataObject()
        self.qualifiers = {}

        # A list of false negatives per class
        self.false_negatives = {_id: [] for _id in self.gt.classes}

        self.pos_thresh = pos_thresh
        self.bg_thresh = bg_thresh
        self.mode = mode
        self.max_dets = max_dets
        self.run_errors = run_errors

        self._run()

    def _run(self):
        """ And awaaay we go """

        for image in self.gt.images:
            x = self.preds.get(image)
            y = self.gt.get(image)

            # These classes are ignored for the whole image and are not in the ground truth,
            # so we can safely remove such detections from the predictions at the start.
            # However, since ignored detections are still used for error calculations, we
            # have to keep them when running errors.
            if not self.run_errors:
                ignored_classes = self.gt._get_ignored_classes(image)
                x = [pred for pred in x if pred['class'] not in ignored_classes]

            self._eval_image(x, y)

        # Store a fixed version of all the errors for testing purposes
        for error in self.errors:
            error.original = f.nonepack(error.unfix())
            error.fixed = f.nonepack(error.fix())
            error.disabled = False

        self.ap = self.ap_data.get_mAP()

        # Now that we've stored the fixed errors, we can clear the gt info
        self._clear()

    def _clear(self):
        """ Clears the ground truth so that it's ready for another run. """
        for gt in self.gt.annotations:
            for var in self._temp_vars:
                if var in gt:
                    del gt[var]

    def _add_error(self, error):
        self.errors.append(error)
        self.error_dict[type(error)].append(error)

    def _eval_image(self, preds:list, gt:list):

        for truth in gt:
            if not truth['ignore']:
                self.ap_data.add_gt_positives(truth['class'], 1)

        if len(preds) == 0:
            # There are no predictions for this image, so add all gt as missed
            for truth in gt:
                if not truth['ignore']:
                    self.ap_data.push_false_negative(truth['class'], truth['_id'])

                    if self.run_errors:
                        self._add_error(MissedError(truth))
                        self.false_negatives[truth['class']].append(truth)
            return

        ex = TIDEExample(preds, gt, self.pos_thresh, self.mode, self.max_dets, self.run_errors)
        preds = ex.preds  # In case the number of predictions was restricted to the max

        for pred_idx, pred in enumerate(preds):

            pred['info'] = {'iou': pred['iou'], 'used': pred['used']}
            if pred['used']: pred['info']['matched_with'] = pred['matched_with']

            if pred['used'] is not None:
                self.ap_data.push(pred['class'], pred['_id'], pred['score'], pred['used'], pred['info'])

            # ----- ERROR DETECTION ------ #
            # This prediction is a negative (or ignored), so let's find out why
            if self.run_errors and (pred['used'] is False or pred['used'] is None):
                # Test for BackgroundError
                if len(ex.gt) == 0:  # Note this is ex.gt because it doesn't include ignore annotations
                    # There is no ground truth for this image, so just mark everything as BackgroundError
                    self._add_error(BackgroundError(pred))
                    continue

                # Test for BoxError
                idx = ex.gt_cls_iou[pred_idx, :].argmax()
                if self.bg_thresh <= ex.gt_cls_iou[pred_idx, idx] <= self.pos_thresh:
                    # This detection would have been a positive if it had higher IoU with this gt
                    self._add_error(BoxError(pred, ex.gt[idx], ex))
                    continue

                # Test for ClassError
                idx = ex.gt_noncls_iou[pred_idx, :].argmax()
                if ex.gt_noncls_iou[pred_idx, idx] >= self.pos_thresh:
                    # This detection would have been a positive if it were the correct class
                    self._add_error(ClassError(pred, ex.gt[idx], ex))
                    continue

                # Test for DuplicateError
                idx = ex.gt_used_cls[pred_idx, :].argmax()
                if ex.gt_used_cls[pred_idx, idx] >= self.pos_thresh:
                    # This detection would have been marked positive, but the gt was already in use
                    suppressor = self.preds.annotations[ex.gt[idx]['matched_with']]
                    self._add_error(DuplicateError(pred, suppressor))
                    continue

                # Test for BackgroundError
                idx = ex.gt_iou[pred_idx, :].argmax()
                if ex.gt_iou[pred_idx, idx] <= self.bg_thresh:
                    # This should have been marked as background
                    self._add_error(BackgroundError(pred))
                    continue

                # A catch-all case for errors that fit none of the above
                self._add_error(OtherError(pred))

        for truth in gt:
            # If the gt wasn't used in matching, it's some kind of false negative
            if not truth['ignore'] and not truth['used']:
                self.ap_data.push_false_negative(truth['class'], truth['_id'])

                if self.run_errors:
                    self.false_negatives[truth['class']].append(truth)

                    # The gt was completely missed, and no error above can correct it
                    # Note: 'usable' is set in the errors module
                    if not truth['usable']:
                        self._add_error(MissedError(truth))

    def fix_errors(self, condition=lambda x: False, transform=None, false_neg_dict:dict=None,
                   ap_data:ClassedAPDataObject=None,
                   disable_errors:bool=False) -> ClassedAPDataObject:
        """ Returns a ClassedAPDataObject where every error for which the condition returns True is fixed. """
        if ap_data is None:
            ap_data = self.ap_data

        gt_pos = ap_data.get_gt_positives()
        new_ap_data = ClassedAPDataObject()

        # Potentially fix every error case
        for error in self.errors:
            if error.disabled:
                continue

            _id = error.get_id()
            _cls, data_point = error.original

            if condition(error):
                _cls, data_point = error.fixed

                if disable_errors:
                    error.disabled = True

            # Specific to MissedError (or anything else that changes the number of GT)
            if isinstance(data_point, int):
                gt_pos[_cls] += data_point
                data_point = None

            if data_point is not None:
                if transform is not None:
                    data_point = transform(*data_point)
                new_ap_data.push(_cls, _id, *data_point)

        # Add back all the correct ones
        for k in gt_pos.keys():
            for _id, (score, correct, info) in ap_data.objs[k].data_points.items():
                if correct:
                    if transform is not None:
                        score, correct, info = transform(score, correct, info)
                    new_ap_data.push(k, _id, score, correct, info)

        # Add the correct amount of GT positives, subtracting where necessary
        for k, v in gt_pos.items():
            # In case you want to fix all false negatives without affecting precision
            if false_neg_dict is not None and k in false_neg_dict:
                v -= len(false_neg_dict[k])
            new_ap_data.add_gt_positives(k, v)

        return new_ap_data

    def fix_main_errors(self, progressive:bool=False, error_types:list=None, qual:Qualifier=None) -> dict:
        ap_data = self.ap_data
        last_ap = self.ap

        if qual is None:
            qual = Qualifier('', None)

        if error_types is None:
            error_types = TIDE._error_types

        errors = {}

        for error in error_types:
            _ap_data = self.fix_errors(qual._make_error_func(error),
                                       ap_data=ap_data, disable_errors=progressive)

            new_ap = _ap_data.get_mAP()
            # If a delta is negative, that's likely due to binning differences,
            # so just ignore it by clamping to 0.
            errors[error] = max(new_ap - last_ap, 0)

            if progressive:
                last_ap = new_ap
                ap_data = _ap_data

        if progressive:
            for error in self.errors:
                error.disabled = False

        return errors
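
    # For instance (a sketch of intended use, mirroring get_main_errors below):
    #   run.fix_main_errors() -> {ClassError: dAP, BoxError: dAP, ...}
    # With progressive=True each type is fixed on top of the previous fixes, so the
    # deltas telescope to (roughly, because of the clamping above) the total gap
    # between the base AP and the fully-corrected AP.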

    def fix_special_errors(self, qual=None) -> dict:
        return {
            FalsePositiveError: self.fix_errors(transform=FalsePositiveError.fix).get_mAP() - self.ap,
            FalseNegativeError: self.fix_errors(false_neg_dict=self.false_negatives).get_mAP() - self.ap}
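
    # (FalsePositiveError re-scores every stored data point through FalsePositiveError.fix,
    # while FalseNegativeError subtracts each class's false-negative count from its GT total
    # via false_neg_dict, i.e. it fixes recall without affecting precision; see fix_errors.)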

    def count_errors(self, error_types:list=None, qual=None):
        counts = {}

        if error_types is None:
            error_types = TIDE._error_types

        for error in error_types:
            if qual is None:
                counts[error] = len(self.error_dict[error])
            else:
                func = qual._make_error_func(error)
                counts[error] = len([x for x in self.errors if func(x)])

        return counts


    def apply_qualifier(self, qualifier:Qualifier) -> ClassedAPDataObject:
        """ Applies a qualifier to the AP object for this run and stores the result in self.qualifiers. """

        pred_keep = defaultdict(set)
        gt_keep = defaultdict(set)

        for pred in self.preds.annotations:
            if qualifier.test(pred):
                pred_keep[pred['class']].add(pred['_id'])

        for gt in self.gt.annotations:
            if not gt['ignore'] and qualifier.test(gt):
                gt_keep[gt['class']].add(gt['_id'])

        new_ap_data = self.ap_data.apply_qualifier(pred_keep, gt_keep)
        self.qualifiers[qualifier.name] = new_ap_data.get_mAP()
        return new_ap_data


class TIDE:
    """
    ████████╗██╗██████╗ ███████╗
    ╚══██╔══╝██║██╔══██╗██╔════╝
       ██║   ██║██║  ██║█████╗
       ██║   ██║██║  ██║██╔══╝
       ██║   ██║██████╔╝███████╗
       ╚═╝   ╚═╝╚═════╝ ╚══════╝
    """

    # This is just here to define a consistent order of the error types
    _error_types = [ClassError, BoxError, OtherError, DuplicateError, BackgroundError, MissedError]
    _special_error_types = [FalsePositiveError, FalseNegativeError]

    # Threshold splits for different challenges
    COCO_THRESHOLDS = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
    VOL_THRESHOLDS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

    # The modes of evaluation
    BOX = 'bbox'
    MASK = 'mask'

    def __init__(self, pos_threshold:float=0.5, background_threshold:float=0.1, mode:str=BOX):
        self.pos_thresh = pos_threshold
        self.bg_thresh = background_threshold
        self.mode = mode

        self.pos_thresh_int = int(self.pos_thresh * 100)

        self.runs = {}
        self.run_thresholds = {}
        self.run_main_errors = {}
        self.run_special_errors = {}

        self.qualifiers = OrderedDict()

        self.plotter = P.Plotter()

    def evaluate(self, gt:Data, preds:Data, pos_threshold:float=None, background_threshold:float=None,
                 mode:str=None, name:str=None, use_for_errors:bool=True) -> TIDERun:
        pos_thresh = self.pos_thresh if pos_threshold is None else pos_threshold
        bg_thresh = self.bg_thresh if background_threshold is None else background_threshold
        mode = self.mode if mode is None else mode
        name = preds.name if name is None else name

        run = TIDERun(gt, preds, pos_thresh, bg_thresh, mode, gt.max_dets, use_for_errors)

        if use_for_errors:
            self.runs[name] = run

        return run

    def evaluate_range(self, gt:Data, preds:Data, thresholds:list=COCO_THRESHOLDS, pos_threshold:float=None,
                       background_threshold:float=None, mode:str=None, name:str=None) -> dict:

        if pos_threshold is None: pos_threshold = self.pos_thresh
        if name is None: name = preds.name

        self.run_thresholds[name] = []

        for thresh in thresholds:

            run = self.evaluate(gt, preds, pos_threshold=thresh, background_threshold=background_threshold,
                                mode=mode, name=name, use_for_errors=(pos_threshold == thresh))

            self.run_thresholds[name].append(run)
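
    # (evaluate_range runs a full TIDERun per threshold; only the run whose threshold
    # equals pos_threshold, 0.5 by default, is also stored in self.runs and therefore
    # used for the error breakdown in summarize().)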

    def add_qualifiers(self, *quals):
        """
        Applies any number of Qualifier objects to the evaluations that have been run so far.
        See qualifiers.py for examples.
        """
        raise NotImplementedError('Qualifiers coming soon.')
        # for q in quals:
        #     for run_name, run in self.runs.items():
        #         if run_name in self.run_thresholds:
        #             # If this was a threshold run, apply the qualifier for every run
        #             for trun in self.run_thresholds[run_name]:
        #                 trun.apply_qualifier(q)
        #         else:
        #             # If this had no threshold, just apply it to the main run
        #             run.apply_qualifier(q)

        #     self.qualifiers[q.name] = q

    def summarize(self):
        """ Summarizes the mAP values and errors for all runs in this TIDE object. Results are printed to the console. """
        main_errors = self.get_main_errors()
        special_errors = self.get_special_errors()

        for run_name, run in self.runs.items():
            print('-- {} --\n'.format(run_name))

            # If we evaluated on all thresholds, print them here
            if run_name in self.run_thresholds:
                thresh_runs = self.run_thresholds[run_name]
                aps = [trun.ap for trun in thresh_runs]

                # Print the overall AP for a threshold run
                ap_title = '{} AP @ [{:d}-{:d}]'.format(thresh_runs[0].mode,
                    int(thresh_runs[0].pos_thresh * 100), int(thresh_runs[-1].pos_thresh * 100))
                print('{:s}: {:.2f}'.format(ap_title, sum(aps) / len(aps)))

                # Print the AP for every threshold on a threshold run
                P.print_table([
                    ['Thresh'] + [str(int(trun.pos_thresh * 100)) for trun in thresh_runs],
                    [' AP '] + ['{:6.2f}'.format(trun.ap) for trun in thresh_runs]
                ], title=ap_title)

                # Print the qualifiers for a threshold run
                if len(self.qualifiers) > 0:
                    print()
                    # Can someone ban me from using list comprehension? This is unreadable.
                    qAPs = [
                        f.mean(
                            [trun.qualifiers[q] for trun in thresh_runs if q in trun.qualifiers]
                        ) for q in self.qualifiers
                    ]

                    P.print_table([
                        ['Name'] + list(self.qualifiers.keys()),
                        [' AP '] + ['{:6.2f}'.format(qAP) for qAP in qAPs]
                    ], title='Qualifiers {}'.format(ap_title))

            # Otherwise, print just the one run we did
            else:
                # Print the overall AP for a regular run
                ap_title = '{} AP @ {:d}'.format(run.mode, int(run.pos_thresh * 100))
                print('{}: {:.2f}'.format(ap_title, run.ap))

                # Print the qualifiers for a regular run
                if len(self.qualifiers) > 0:
                    print()
                    qAPs = [run.qualifiers[q] if q in run.qualifiers else 0 for q in self.qualifiers]
                    P.print_table([
                        ['Name'] + list(self.qualifiers.keys()),
                        [' AP '] + ['{:6.2f}'.format(qAP) for qAP in qAPs]
                    ], title='Qualifiers {}'.format(ap_title))

            print()
            # Print the main errors
            P.print_table([
                ['Type'] + [err.short_name for err in TIDE._error_types],
                [' dAP'] + ['{:6.2f}'.format(main_errors[run_name][err.short_name]) for err in TIDE._error_types]
            ], title='Main Errors')

            print()
            # Print the special errors
            P.print_table([
                ['Type'] + [err.short_name for err in TIDE._special_error_types],
                [' dAP'] + ['{:6.2f}'.format(special_errors[run_name][err.short_name]) for err in TIDE._special_error_types]
            ], title='Special Errors')

            print()

    def plot(self, out_dir:str=None):
        """
        Plots a summary of the errors for each run in this TIDE object.
        Images will be written to out_dir, which will be created if it doesn't exist.
        """

        if out_dir is not None:
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)

        errors = self.get_all_errors()

        max_main_error = max(sum([list(x.values()) for x in errors['main'].values()], []))
        max_spec_error = max(sum([list(x.values()) for x in errors['special'].values()], []))
        dap_granularity = 5  # The max will round up to the nearest unit of this

        # Round the plotter's dAP range up to the nearest granularity units
        if max_main_error > self.plotter.MAX_MAIN_DELTA_AP:
            self.plotter.MAX_MAIN_DELTA_AP = math.ceil(max_main_error / dap_granularity) * dap_granularity
        if max_spec_error > self.plotter.MAX_SPECIAL_DELTA_AP:
            self.plotter.MAX_SPECIAL_DELTA_AP = math.ceil(max_spec_error / dap_granularity) * dap_granularity

        # Do the plotting now
        for run_name, run in self.runs.items():
            self.plotter.make_summary_plot(out_dir, errors, run_name, run.mode, hbar_names=True)

    def get_main_errors(self):
        errors = {}

        for run_name, run in self.runs.items():
            if run_name in self.run_main_errors:
                errors[run_name] = self.run_main_errors[run_name]
            else:
                errors[run_name] = {
                    error.short_name: value
                    for error, value in run.fix_main_errors().items()
                }

        return errors

    def get_special_errors(self):
        errors = {}

        for run_name, run in self.runs.items():
            if run_name in self.run_special_errors:
                errors[run_name] = self.run_special_errors[run_name]
            else:
                errors[run_name] = {
                    error.short_name: value
                    for error, value in run.fix_special_errors().items()
                }

        return errors

    def get_all_errors(self):
        """
        Returns {
            'main'   : { run_name: { error_name: float } },
            'special': { run_name: { error_name: float } },
        }
        """
        return {
            'main': self.get_main_errors(),
            'special': self.get_special_errors()
        }