Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
dbolya
GitHub Repository: dbolya/tide
Path: blob/master/tidecv/plotting.py
110 views
1
2
from collections import defaultdict, OrderedDict
3
import os
4
import shutil
5
6
import cv2
7
from matplotlib.lines import Line2D
8
import matplotlib.pyplot as plt
9
import matplotlib as mpl
10
import numpy as np
11
import pandas as pd
12
import seaborn as sns
13
14
from .errors.main_errors import *
15
from .datasets import get_tide_path
16
17
18
19
def print_table(rows:list, title:str=None):
    """Print ``rows`` as a right-aligned plain-text table.

    The first row is the header and is followed by a thin ``-`` divider; the whole
    table is framed by thick ``=`` dividers. Rows shorter than the longest row are
    padded with empty cells. Every row line is exactly as wide as the dividers.

    :param rows:  List of rows, each a list of strings; ``rows[0]`` is the header.
                  The caller's lists are not modified. An empty list prints nothing.
    :param title: Optional title, printed roughly centered above the table.
    """
    if not rows:
        return

    # Pad copies of the rows (never the caller's lists) to a common column count.
    num_cols = max(len(row) for row in rows)
    table = [list(row) + [''] * (num_cols - len(row)) for row in rows]

    # Compute the text width of each column
    col_widths = [max(len(row[col]) for row in table) for col in range(num_cols)]

    # Layout: 2-char left margin, 3-char gap between columns, 1-char right margin.
    # Row lines below use the same geometry so they line up with the dividers.
    divider = '--' + '---'.join('-' * w for w in col_widths) + '-'
    thick_divider = divider.replace('-', '=')

    if title:
        left_pad = (len(divider) - len(title)) // 2
        print(' ' * max(left_pad, 0) + title)

    print(thick_divider)
    for row_idx, row in enumerate(table):
        # Right-align each cell to its column's text width
        cells = [cell.rjust(width) for cell, width in zip(row, col_widths)]
        print('  ' + '   '.join(cells) + ' ')
        # Divider only after the header row, even if a later row has the same contents
        if row_idx == 0:
            print(divider)
    print(thick_divider)
42
43
44
45
46
class Plotter():
    """ Sets up a seaborn environment and holds the functions for plotting our figures. """

    def __init__(self, quality:float=1):
        """Configure the global matplotlib/seaborn style and the per-error-type colors.

        :param quality: Scale factor for the DPIs used by ``make_summary_plot``
                        (figures are created at ``500*quality`` and saved at ``300*quality``).
        """
        # Set mpl DPI in case we want to output to the screen / notebook
        mpl.rcParams['figure.dpi'] = 150

        # Seaborn color palette. Captured here, before sns.set() below, which also
        # (re)sets the global palette.
        sns.set_palette('muted', 10)
        current_palette = sns.color_palette()

        # Seaborn style
        sns.set(style="whitegrid")

        # Fixed color per error type, keyed by the error class's short name.
        self.colors_main = OrderedDict({
            ClassError     .short_name: current_palette[9],
            BoxError       .short_name: current_palette[8],
            OtherError     .short_name: current_palette[2],
            DuplicateError .short_name: current_palette[6],
            BackgroundError.short_name: current_palette[4],
            MissedError    .short_name: current_palette[3],
        })

        self.colors_special = OrderedDict({
            FalsePositiveError.short_name: current_palette[0],
            FalseNegativeError.short_name: current_palette[1],
        })

        self.tide_path = get_tide_path()

        # For the purposes of comparing across models, we fix the scales on our bar plots.
        # Feel free to change these after initializing if you want to change the scale.
        self.MAX_MAIN_DELTA_AP = 10
        self.MAX_SPECIAL_DELTA_AP = 25

        self.quality = quality

    def _prepare_tmp_dir(self):
        """Return ``<tide_path>/_tmp``, creating it if needed and emptying out any previous contents."""
        tmp_dir = os.path.join(self.tide_path, '_tmp')
        os.makedirs(tmp_dir, exist_ok=True)

        for entry in os.listdir(tmp_dir):
            path = os.path.join(tmp_dir, entry)
            # os.remove cannot delete directories, so clear those with rmtree.
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.remove(path)

        return tmp_dir

    def make_summary_plot(self, out_dir:str, errors:dict, model_name:str, rec_type:str, hbar_names:bool=False):
        """Make a summary plot of the errors for a model, and save it to the figs folder.

        The summary is a pie chart of the main errors' relative contributions stacked above
        a horizontal bar plot (main errors) placed next to a vertical bar plot (special errors).
        The three parts are rendered as PNGs into ``<tide_path>/_tmp`` (emptied first) and
        then combined.

        :param out_dir: The output directory for the summary image. MUST EXIST. If None,
                        the summary is displayed with ``plt.show()`` instead of being saved.
        :param errors: Dictionary of both main and special errors, indexed as
                       ``errors['main' or 'special'][model_name][error_short_name] -> delta mAP``.
        :param model_name: Name of the model for which to generate the plot.
        :param rec_type: Recognition type, either TIDE.BOX or TIDE.MASK
        :param hbar_names: Whether or not to include labels for the horizontal bars.
        """
        tmp_dir = self._prepare_tmp_dir()

        high_dpi = int(500*self.quality)
        low_dpi = int(300*self.quality)

        main_errors = errors['main'][model_name]
        special_errors = errors['special'][model_name]

        def tmp_path(part):
            return os.path.join(tmp_dir, '{}_{}_{}.png'.format(model_name, rec_type, part))

        pie_path, hbar_path, vbar_path = tmp_path('pie'), tmp_path('hbar'), tmp_path('vbar')

        self._save_pie(pie_path, main_errors, model_name, high_dpi, low_dpi)
        self._save_hbar(hbar_path, main_errors, hbar_names, high_dpi, low_dpi)
        self._save_vbar(vbar_path, special_errors, high_dpi, low_dpi)

        summary_im = self._compose_summary(cv2.imread(pie_path), cv2.imread(hbar_path), cv2.imread(vbar_path))

        if out_dir is None:
            fig = plt.figure()
            ax = fig.add_axes([0, 0, 1, 1])
            ax.set_axis_off()
            # OpenCV images are BGR; matplotlib expects RGB.
            ax.imshow(summary_im[:, :, ::-1])
            plt.show()
            plt.close(fig)
        else:
            cv2.imwrite(os.path.join(out_dir, '{}_{}_summary.png'.format(model_name, rec_type)), summary_im)

    @staticmethod
    def _colors_for(color_map, names):
        """Return the color for each error name, looked up by name in ``color_map``.

        Names missing from ``color_map`` fall back to the color at the same position
        (cycling), so unexpected error types still get a color.
        """
        fallback = list(color_map.values())
        return [color_map.get(name, fallback[i % len(fallback)]) for i, name in enumerate(names)]

    def _save_pie(self, path, main_errors, title, high_dpi, low_dpi):
        """Render a pie chart of each main error's share of the total main delta mAP to ``path``."""
        names = list(main_errors.keys())
        values = list(main_errors.values())
        total = sum(values)

        fig, ax = plt.subplots(1, 1, figsize=(11, 11), dpi=high_dpi)
        # With no main error at all there is nothing to split up: leave the pie empty
        # rather than dividing by zero.
        if total > 0:
            sizes = [v / total for v in values]
            _, outer_text, inner_text = ax.pie(sizes, colors=self._colors_for(self.colors_main, names),
                                               labels=names, autopct='%1.1f%%', startangle=90)
            # Hide matplotlib's own labels/percentages, then write the error name inside
            # every wedge large enough (> 5%) to hold it.
            for text in list(outer_text) + list(inner_text):
                text.set_text('')
            for name, size, text in zip(names, sizes, inner_text):
                if size > 0.05:
                    text.set_text(name)
                    text.set_fontsize(48)
                    text.set_fontweight('bold')
        ax.axis('equal')
        ax.set_title(title, fontdict={'fontsize': 60, 'fontweight': 'bold'})
        fig.savefig(path, bbox_inches='tight', dpi=low_dpi)
        plt.close(fig)

    def _save_hbar(self, path, main_errors, show_names, high_dpi, low_dpi):
        """Render a horizontal bar plot of the main errors' delta mAP to ``path``."""
        names = list(main_errors.keys())
        df = pd.DataFrame(data={
            'Error Type': names,
            'Delta mAP': list(main_errors.values()),
        })

        fig, ax = plt.subplots(1, 1, figsize=(6, 5), dpi=high_dpi)
        sns.barplot(data=df, x='Delta mAP', y='Error Type', ax=ax,
                    palette=self._colors_for(self.colors_main, names))
        # Fixed scale so plots of different models are comparable
        ax.set_xlim(0, self.MAX_MAIN_DELTA_AP)
        ax.set_xlabel('')
        ax.set_ylabel('')
        if not show_names:
            # One blank label per category tick, however many main errors there are
            ax.set_yticklabels([''] * len(ax.get_yticks()))
        plt.setp(ax.get_xticklabels(), fontsize=28)
        plt.setp(ax.get_yticklabels(), fontsize=36)
        ax.grid(False)
        sns.despine(ax=ax, left=True, bottom=True, right=True)
        fig.savefig(path, bbox_inches='tight', dpi=low_dpi)
        plt.close(fig)

    def _save_vbar(self, path, special_errors, high_dpi, low_dpi):
        """Render a vertical bar plot of the special errors' delta mAP to ``path``."""
        names = list(special_errors.keys())
        df = pd.DataFrame(data={
            'Error Type': names,
            'Delta mAP': list(special_errors.values()),
        })

        fig, ax = plt.subplots(1, 1, figsize=(2, 5), dpi=high_dpi)
        sns.barplot(data=df, x='Error Type', y='Delta mAP', ax=ax,
                    palette=self._colors_for(self.colors_special, names))
        # Fixed scale so plots of different models are comparable
        ax.set_ylim(0, self.MAX_SPECIAL_DELTA_AP)
        ax.set_xlabel('')
        ax.set_ylabel('')
        # Short tick labels matched by name: 'FP' / 'FN' in the order colors_special declares them
        abbrevs = dict(zip(self.colors_special.keys(), ['FP', 'FN']))
        ax.set_xticklabels([abbrevs.get(name, name) for name in names])
        plt.setp(ax.get_xticklabels(), fontsize=36)
        plt.setp(ax.get_yticklabels(), fontsize=28)
        ax.grid(False)
        sns.despine(ax=ax, left=True, bottom=True, right=True)
        fig.savefig(path, bbox_inches='tight', dpi=low_dpi)
        plt.close(fig)

    @staticmethod
    def _pad_to_height(im, height):
        """Pad ``im`` (H x W x C) with white rows on top so it is ``height`` rows tall."""
        extra = height - im.shape[0]
        if extra <= 0:
            return im
        pad = np.full((extra,) + im.shape[1:], 255, dtype=im.dtype)
        return np.concatenate([pad, im], axis=0)

    @staticmethod
    def _pad_to_width(im, width):
        """Pad ``im`` (H x W x C) with white columns to ``width``, splitting the padding
        across both sides (the odd column, if any, goes on the left)."""
        extra = width - im.shape[1]
        if extra <= 0:
            return im
        left = (extra + 1) // 2
        right = extra - left
        lpad = np.full((im.shape[0], left) + im.shape[2:], 255, dtype=im.dtype)
        rpad = np.full((im.shape[0], right) + im.shape[2:], 255, dtype=im.dtype)
        return np.concatenate([lpad, im, rpad], axis=1)

    @classmethod
    def _compose_summary(cls, pie_im, hbar_im, vbar_im):
        """Combine the three plot images: the bars side by side (hbar left, vbar right),
        under the pie. Padding is white and keeps the input dtype (uint8 from cv2.imread),
        which avoids promoting the large images to float64."""
        bar_height = max(hbar_im.shape[0], vbar_im.shape[0])
        bars_im = np.concatenate([cls._pad_to_height(hbar_im, bar_height),
                                  cls._pad_to_height(vbar_im, bar_height)], axis=1)

        width = max(pie_im.shape[1], bars_im.shape[1])
        return np.concatenate([cls._pad_to_width(pie_im, width),
                               cls._pad_to_width(bars_im, width)], axis=0)
209
210
211