# GitHub Repository: jantic/deoldify
# Path: blob/master/fastai/metrics.py

"Implements various metrics to measure training accuracy"
from .torch_core import *
from .callback import *
from .layers import *
from .basic_train import LearnerCallback

__all__ = ['error_rate', 'accuracy', 'accuracy_thresh', 'dice', 'exp_rmspe', 'fbeta', 'FBeta', 'mse', 'mean_squared_error',
           'mae', 'mean_absolute_error', 'rmse', 'root_mean_squared_error', 'msle', 'mean_squared_logarithmic_error',
           'explained_variance', 'r2_score', 'top_k_accuracy', 'KappaScore', 'ConfusionMatrix', 'MatthewsCorreff',
           'Precision', 'Recall', 'R2Score', 'ExplainedVariance', 'ExpRMSPE', 'RMSE', 'Perplexity', 'AUROC', 'auc_roc_score',
           'roc_curve', 'MultiLabelFbeta', 'foreground_acc']

def fbeta(y_pred:Tensor, y_true:Tensor, thresh:float=0.2, beta:float=2, eps:float=1e-9, sigmoid:bool=True)->Rank0Tensor:
    "Computes the f_beta score between `y_pred` and `y_true`."
    beta2 = beta ** 2
    if sigmoid: y_pred = y_pred.sigmoid()
    y_pred = (y_pred>thresh).float()
    y_true = y_true.float()
    TP = (y_pred*y_true).sum(dim=1)
    prec = TP/(y_pred.sum(dim=1)+eps)
    rec = TP/(y_true.sum(dim=1)+eps)
    res = (prec*rec)/(prec*beta2+rec+eps)*(1+beta2)
    return res.mean()
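
# Example usage (an illustrative sketch, not part of the original module):
# `fbeta` averages a per-sample F-beta over the batch, so it suits one-hot
# encoded multi-label targets.
#
#   >>> y_pred = torch.tensor([[ 2.0, -1.0, 3.0], [-2.0, 1.5, 0.5]])  # raw logits
#   >>> y_true = torch.tensor([[ 1.,   0.,  1. ], [ 0.,  1.,  1. ]])
#   >>> fbeta(y_pred, y_true, thresh=0.5, beta=2)  # sigmoid applied internally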

def accuracy(input:Tensor, targs:Tensor)->Rank0Tensor:
    "Computes accuracy with `targs` when `input` is bs * n_classes."
    n = targs.shape[0]
    input = input.argmax(dim=-1).view(n,-1)
    targs = targs.view(n,-1)
    return (input==targs).float().mean()

def accuracy_thresh(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True)->Rank0Tensor:
    "Computes accuracy when `y_pred` and `y_true` are the same size."
    if sigmoid: y_pred = y_pred.sigmoid()
    return ((y_pred>thresh)==y_true.byte()).float().mean()

def top_k_accuracy(input:Tensor, targs:Tensor, k:int=5)->Rank0Tensor:
    "Computes the Top-k accuracy (target is in the top k predictions)."
    input = input.topk(k=k, dim=-1)[1]
    targs = targs.unsqueeze(dim=-1).expand_as(input)
    return (input == targs).max(dim=-1)[0].float().mean()
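
# Worked example (an illustrative sketch, not part of the original module):
# with class 2 ranked first and class 0 second for the lone sample below,
# plain accuracy and top-2 accuracy disagree.
#
#   >>> scores = torch.tensor([[0.3, 0.1, 0.6]])
#   >>> target = torch.tensor([0])
#   >>> accuracy(scores, target)             # tensor(0.) -- argmax is class 2
#   >>> top_k_accuracy(scores, target, k=2)  # tensor(1.) -- class 0 is in the top 2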

def foreground_acc(input, target, void_code):
    "Computes non-background accuracy for multiclass segmentation (e.g. CamVid), ignoring pixels labelled `void_code`."
    target = target.squeeze(1)
    mask = target != void_code
    return (input.argmax(dim=1)[mask]==target[mask]).float().mean()
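
# Example usage (an illustrative sketch, not part of the original module):
# for segmentation, `input` is (bs, n_classes, H, W) and `target` is
# (bs, 1, H, W); pixels labelled with `void_code` are excluded.
#
#   >>> logits = torch.randn(2, 4, 8, 8)           # 4 classes
#   >>> mask = torch.randint(0, 4, (2, 1, 8, 8))
#   >>> foreground_acc(logits, mask, void_code=0)  # accuracy over non-void pixels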

def error_rate(input:Tensor, targs:Tensor)->Rank0Tensor:
    "1 - `accuracy`"
    return 1 - accuracy(input, targs)

def dice(input:Tensor, targs:Tensor, iou:bool=False, eps:float=1e-8)->Rank0Tensor:
    "Dice coefficient metric for a binary target. If `iou=True`, returns the IoU (Jaccard) metric instead, classic for segmentation problems."
    n = targs.shape[0]
    input = input.argmax(dim=1).view(n,-1)
    targs = targs.view(n,-1)
    intersect = (input * targs).sum().float()
    union = (input+targs).sum().float()
    if not iou: return (2. * intersect / union if union > 0 else union.new([1.]).squeeze())
    else: return (intersect / (union-intersect+eps) if union > 0 else union.new([1.]).squeeze())
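
# Worked example (an illustrative sketch, not part of the original module):
# with 3 predicted-positive pixels, 4 true-positive pixels, and an overlap of
# 2, dice = 2*2/(3+4) ≈ 0.571 while IoU = 2/(3+4-2) = 0.4. Note that `union`
# above is |pred| + |targ|, so the IoU denominator subtracts the intersection
# once to avoid double-counting the overlap.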

def psnr(input:Tensor, targs:Tensor)->Rank0Tensor:
    "Peak signal-to-noise ratio between `input` and `targs`, assuming values scaled to a peak of 1."
    return 10 * (1. / mean_squared_error(input, targs)).log10()
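
# Worked example (an illustrative sketch, not part of the original module):
# an MSE of 0.01 on images scaled to [0, 1] gives 10 * log10(1 / 0.01) = 20 dB.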

def exp_rmspe(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Root mean square percentage error between `exp(pred)` and `exp(targ)`."
    pred,targ = flatten_check(pred,targ)
    pred, targ = torch.exp(pred), torch.exp(targ)
    pct_var = (targ - pred)/targ
    return torch.sqrt((pct_var**2).mean())

def mean_absolute_error(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Mean absolute error between `pred` and `targ`."
    pred,targ = flatten_check(pred,targ)
    return torch.abs(targ - pred).mean()

def mean_squared_error(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Mean squared error between `pred` and `targ`."
    pred,targ = flatten_check(pred,targ)
    return F.mse_loss(pred, targ)

def root_mean_squared_error(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Root mean squared error between `pred` and `targ`."
    pred,targ = flatten_check(pred,targ)
    return torch.sqrt(F.mse_loss(pred, targ))

def mean_squared_logarithmic_error(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Mean squared logarithmic error between `pred` and `targ`."
    pred,targ = flatten_check(pred,targ)
    return F.mse_loss(torch.log(1 + pred), torch.log(1 + targ))

def explained_variance(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Explained variance between `pred` and `targ`."
    pred,targ = flatten_check(pred,targ)
    var_pct = torch.var(targ - pred) / torch.var(targ)
    return 1 - var_pct

def r2_score(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "R2 score (coefficient of determination) between `pred` and `targ`."
    pred,targ = flatten_check(pred,targ)
    u = torch.sum((targ - pred) ** 2)
    d = torch.sum((targ - targ.mean()) ** 2)
    return 1 - u / d
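
# Example usage (an illustrative sketch, not part of the original module):
# these regression metrics can also be called directly on flat tensors.
#
#   >>> pred = torch.tensor([2.5, 0.0, 2.0, 8.0])
#   >>> targ = torch.tensor([3.0, -0.5, 2.0, 7.0])
#   >>> mean_absolute_error(pred, targ)  # tensor(0.5000)
#   >>> r2_score(pred, targ)             # ~0.949, matching sklearn's r2_score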

class RegMetrics(Callback):
    "Stores predictions and targets to perform calculations on epoch end."
    def on_epoch_begin(self, **kwargs):
        self.targs, self.preds = Tensor([]), Tensor([])

    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        assert last_output.numel() == last_target.numel(), "Expected the same number of elements in pred & targ"
        self.preds = torch.cat((self.preds, last_output.cpu()))
        self.targs = torch.cat((self.targs, last_target.cpu()))

class R2Score(RegMetrics):
    "Computes the R2 score (coefficient of determination)."
    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, r2_score(self.preds, self.targs))

class ExplainedVariance(RegMetrics):
    "Computes the explained variance."
    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, explained_variance(self.preds, self.targs))

class RMSE(RegMetrics):
    "Computes the root mean squared error."
    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, root_mean_squared_error(self.preds, self.targs))

class ExpRMSPE(RegMetrics):
    "Computes the root mean square percentage error of `exp(preds)` and `exp(targs)`."
    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, exp_rmspe(self.preds, self.targs))

# Aliases
mse = mean_squared_error
mae = mean_absolute_error
msle = mean_squared_logarithmic_error
rmse = root_mean_squared_error
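
# Example usage (an illustrative sketch, not part of the original module; the
# `tabular_learner` name comes from the surrounding fastai v1 API): plain
# metric functions and Callback-based metrics can be mixed in `metrics`.
#
#   >>> learn = tabular_learner(data, layers=[200, 100],
#   ...                         metrics=[rmse, mae, R2Score(), ExplainedVariance()])
#   >>> learn.fit_one_cycle(3)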

class ConfusionMatrix(Callback):
    "Computes the confusion matrix."

    def on_train_begin(self, **kwargs):
        self.n_classes = 0

    def on_epoch_begin(self, **kwargs):
        self.cm = None

    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        preds = last_output.argmax(-1).view(-1).cpu()
        targs = last_target.cpu()
        if self.n_classes == 0:
            self.n_classes = last_output.shape[-1]
            self.x = torch.arange(0, self.n_classes)
        cm = ((preds==self.x[:, None]) & (targs==self.x[:, None, None])).sum(dim=2, dtype=torch.float32)
        if self.cm is None: self.cm = cm
        else: self.cm += cm

    def on_epoch_end(self, **kwargs):
        self.metric = self.cm
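
# The one-liner in `on_batch_end` builds the confusion matrix by broadcasting:
# for each (true class i, predicted class j) pair it counts the samples where
# both match. A minimal standalone sketch of the same trick (not part of the
# original module):
#
#   >>> preds = torch.tensor([0, 1, 1, 2])
#   >>> targs = torch.tensor([0, 1, 2, 2])
#   >>> x = torch.arange(3)
#   >>> ((preds==x[:, None]) & (targs==x[:, None, None])).sum(dim=2)
#   tensor([[1, 0, 0],          # rows = true class,
#           [0, 1, 0],          # columns = predicted class
#           [0, 1, 1]])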

@dataclass
class CMScores(ConfusionMatrix):
    "Base class for metrics which rely on the calculation of the precision and/or recall score."
    average:Optional[str]="binary"      # `binary`, `micro`, `macro`, `weighted` or None
    pos_label:int=1                     # 0 or 1
    eps:float=1e-9

    def _recall(self):
        rec = torch.diag(self.cm) / self.cm.sum(dim=1)
        if self.average is None: return rec
        else:
            if self.average == "micro": weights = self._weights(avg="weighted")
            else: weights = self._weights(avg=self.average)
            return (rec * weights).sum()

    def _precision(self):
        prec = torch.diag(self.cm) / self.cm.sum(dim=0)
        if self.average is None: return prec
        else:
            weights = self._weights(avg=self.average)
            return (prec * weights).sum()

    def _weights(self, avg:str):
        if self.n_classes != 2 and avg == "binary":
            avg = self.average = "macro"
            warn("average=`binary` was selected for a non-binary case. Value for average has now been set to `macro` instead.")
        if avg == "binary":
            if self.pos_label not in (0, 1):
                self.pos_label = 1
                warn("Invalid value for pos_label. It has now been set to 1.")
            if self.pos_label == 1: return Tensor([0,1])
            else: return Tensor([1,0])
        elif avg == "micro": return self.cm.sum(dim=0) / self.cm.sum()
        elif avg == "macro": return torch.ones((self.n_classes,)) / self.n_classes
        elif avg == "weighted": return self.cm.sum(dim=1) / self.cm.sum()
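
# How the averaging options map to weights over per-class scores (a summary of
# the code above, not part of the original module):
#   - `binary`:   a one-hot weight on `pos_label`, so only that class counts;
#   - `macro`:    uniform weights 1/n_classes (every class counts equally);
#   - `weighted`: each class weighted by its share of true samples (row sums);
#   - `micro`:    column sums / total for precision; `_recall` reuses the
#                 `weighted` scheme, which is algebraically equal to micro
#                 recall: sum_i (support_i/total) * (TP_i/support_i) = sum TP / total.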

class Recall(CMScores):
    "Computes the Recall."
    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, self._recall())

class Precision(CMScores):
    "Computes the Precision."
    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, self._precision())

@dataclass
class FBeta(CMScores):
    "Computes the F`beta` score."
    beta:float=2

    def on_train_begin(self, **kwargs):
        self.n_classes = 0
        self.beta2 = self.beta ** 2
        self.avg = self.average
        if self.average != "micro": self.average = None

    def on_epoch_end(self, last_metrics, **kwargs):
        prec = self._precision()
        rec = self._recall()
        metric = (1 + self.beta2) * prec * rec / (prec * self.beta2 + rec + self.eps)
        metric[metric != metric] = 0  # removing potential NaNs
        if self.avg: metric = (self._weights(avg=self.avg) * metric).sum()
        return add_metrics(last_metrics, metric)

    def on_train_end(self, **kwargs): self.average = self.avg

@dataclass
class KappaScore(ConfusionMatrix):
    "Computes the rate of agreement (Cohen's kappa)."
    weights:Optional[str]=None      # None, `linear`, or `quadratic`

    def on_epoch_end(self, last_metrics, **kwargs):
        sum0 = self.cm.sum(dim=0)
        sum1 = self.cm.sum(dim=1)
        expected = torch.einsum('i,j->ij', (sum0, sum1)) / sum0.sum()
        if self.weights is None:
            w = torch.ones((self.n_classes, self.n_classes))
            w[self.x, self.x] = 0
        elif self.weights == "linear" or self.weights == "quadratic":
            w = torch.zeros((self.n_classes, self.n_classes))
            w += torch.arange(self.n_classes, dtype=torch.float)
            w = torch.abs(w - torch.t(w)) if self.weights == "linear" else (w - torch.t(w)) ** 2
        else: raise ValueError('Unknown weights. Expected None, "linear", or "quadratic".')
        k = torch.sum(w * self.cm) / torch.sum(w * expected)
        return add_metrics(last_metrics, 1-k)
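
# The disagreement weight matrix for 3 classes (a worked illustration, not
# part of the original module). Entry (i, j) penalises predicting j when the
# true class is i:
#
#   unweighted          linear              quadratic
#   [[0, 1, 1],         [[0, 1, 2],         [[0, 1, 4],
#    [1, 0, 1],          [1, 0, 1],          [1, 0, 1],
#    [1, 1, 0]]          [2, 1, 0]]          [4, 1, 0]]
#
# Quadratic weighting (as in the Kaggle "quadratic weighted kappa" metric)
# penalises distant misclassifications much more heavily than adjacent ones.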

@dataclass
class MatthewsCorreff(ConfusionMatrix):
    "Computes the Matthews correlation coefficient."
    def on_epoch_end(self, last_metrics, **kwargs):
        t_sum = self.cm.sum(dim=1)
        p_sum = self.cm.sum(dim=0)
        n_correct = torch.trace(self.cm)
        n_samples = p_sum.sum()
        cov_ytyp = n_correct * n_samples - torch.dot(t_sum, p_sum)
        cov_ypyp = n_samples ** 2 - torch.dot(p_sum, p_sum)
        cov_ytyt = n_samples ** 2 - torch.dot(t_sum, t_sum)
        return add_metrics(last_metrics, cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp))
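
# This mirrors sklearn's multiclass MCC: with c = correct predictions,
# s = total samples, t_k = true counts and p_k = predicted counts per class,
#
#   MCC = (c*s - sum_k t_k*p_k) / sqrt((s^2 - sum_k p_k^2) * (s^2 - sum_k t_k^2))
#
# which reduces to the familiar (TP*TN - FP*FN) / sqrt(...) form in the
# binary case.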

class Perplexity(Callback):
    "Perplexity metric for language models."
    def on_epoch_begin(self, **kwargs): self.loss,self.len = 0.,0

    def on_batch_end(self, last_output, last_target, **kwargs):
        self.loss += last_target.size(1) * CrossEntropyFlat()(last_output, last_target)
        self.len += last_target.size(1)

    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, torch.exp(self.loss / self.len))
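
# Perplexity is the exponential of the average per-token cross-entropy, which
# is why the running sum above is weighted by sequence length before being
# re-normalised. A worked illustration (not part of the original module): an
# average loss of ln(10) ≈ 2.303 per token corresponds to a perplexity of 10,
# i.e. the model is as uncertain as a uniform choice over 10 tokens.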

def auc_roc_score(input:Tensor, targ:Tensor):
    "Computes the area under the receiver operating characteristic (ROC) curve using the trapezoid method. Restricted to binary classification tasks."
    fpr, tpr = roc_curve(input, targ)
    d = fpr[1:] - fpr[:-1]
    sl1, sl2 = [slice(None)], [slice(None)]
    sl1[-1], sl2[-1] = slice(1, None), slice(None, -1)
    return (d * (tpr[tuple(sl1)] + tpr[tuple(sl2)]) / 2.).sum(-1)

def roc_curve(input:Tensor, targ:Tensor):
    "Computes the receiver operating characteristic (ROC) curve by determining the true positive rate (TPR) and false positive rate (FPR) for various classification thresholds. Restricted to binary classification tasks."
    targ = (targ == 1)
    desc_score_indices = torch.flip(input.argsort(-1), [-1])
    input = input[desc_score_indices]
    targ = targ[desc_score_indices]
    d = input[1:] - input[:-1]
    distinct_value_indices = torch.nonzero(d).transpose(0,1)[0]
    threshold_idxs = torch.cat((distinct_value_indices, LongTensor([len(targ) - 1]).to(targ.device)))
    tps = torch.cumsum(targ * 1, dim=-1)[threshold_idxs]
    fps = (1 + threshold_idxs - tps)
    if tps[0] != 0 or fps[0] != 0:
        fps = torch.cat((LongTensor([0]), fps))
        tps = torch.cat((LongTensor([0]), tps))
    fpr, tpr = fps.float() / fps[-1], tps.float() / tps[-1]
    return fpr, tpr
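
# Example usage (an illustrative sketch, not part of the original module):
# scores for the positive class plus binary targets, following sklearn's
# descending-threshold convention.
#
#   >>> scores = torch.tensor([0.1, 0.4, 0.35, 0.8])
#   >>> targ = torch.tensor([0, 0, 1, 1])
#   >>> fpr, tpr = roc_curve(scores, targ)
#   >>> auc_roc_score(scores, targ)  # tensor(0.7500), as sklearn reports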

@dataclass
class AUROC(Callback):
    "Computes the area under the curve (AUC) score based on the receiver operating characteristic (ROC) curve. Restricted to binary classification tasks."
    def on_epoch_begin(self, **kwargs):
        self.targs, self.preds = LongTensor([]), Tensor([])

    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        last_output = F.softmax(last_output, dim=1)[:,-1]  # probability of the positive class (last column)
        self.preds = torch.cat((self.preds, last_output.cpu()))
        self.targs = torch.cat((self.targs, last_target.cpu().long()))

    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, auc_roc_score(self.preds, self.targs))

class MultiLabelFbeta(LearnerCallback):
    "Computes the fbeta score for multilabel classification."
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
    _order = -20
    def __init__(self, learn, beta=2, eps=1e-15, thresh=0.3, sigmoid=True, average="micro"):
        super().__init__(learn)
        self.eps, self.thresh, self.sigmoid, self.average, self.beta2 = \
            eps, thresh, sigmoid, average, beta**2

    def on_train_begin(self, **kwargs):
        self.c = self.learn.data.c
        if self.average != "none": self.learn.recorder.add_metric_names([f'{self.average}_fbeta'])
        else: self.learn.recorder.add_metric_names([f"fbeta_{c}" for c in self.learn.data.classes])

    def on_epoch_begin(self, **kwargs):
        dvc = self.learn.data.device
        self.tp = torch.zeros(self.c).to(dvc)
        self.total_pred = torch.zeros(self.c).to(dvc)
        self.total_targ = torch.zeros(self.c).to(dvc)

    def on_batch_end(self, last_output, last_target, **kwargs):
        pred, targ = (last_output.sigmoid() if self.sigmoid else last_output) > self.thresh, last_target.byte()
        m = pred*targ
        self.tp += m.sum(0).float()
        self.total_pred += pred.sum(0).float()
        self.total_targ += targ.sum(0).float()

    def fbeta_score(self, precision, recall):
        return (1 + self.beta2)*(precision*recall)/((self.beta2*precision + recall) + self.eps)

    def on_epoch_end(self, last_metrics, **kwargs):
        self.total_pred += self.eps
        self.total_targ += self.eps
        if self.average == "micro":
            precision, recall = self.tp.sum() / self.total_pred.sum(), self.tp.sum() / self.total_targ.sum()
            res = self.fbeta_score(precision, recall)
        elif self.average == "macro":
            res = self.fbeta_score((self.tp / self.total_pred), (self.tp / self.total_targ)).mean()
        elif self.average == "weighted":
            scores = self.fbeta_score((self.tp / self.total_pred), (self.tp / self.total_targ))
            res = (scores*self.total_targ).sum() / self.total_targ.sum()
        elif self.average == "none":
            res = listify(self.fbeta_score((self.tp / self.total_pred), (self.tp / self.total_targ)))
        else:
            raise Exception("Choose one of the average types: [micro, macro, weighted, none]")

        return add_metrics(last_metrics, res)
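
# Example usage (an illustrative sketch, not part of the original module;
# as a LearnerCallback it takes `learn` as its first argument, so it is
# typically wired up through `callback_fns` with `partial`):
#
#   >>> from functools import partial
#   >>> learn = cnn_learner(data, models.resnet34,
#   ...                     callback_fns=[partial(MultiLabelFbeta, beta=1, average="macro")])
#   >>> learn.fit_one_cycle(5)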