Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/helpers/metrics.py
4928 views
import dataclasses
from abc import ABC
from typing import Any

import torch

from labml import tracker
6
7
8
class StateModule:
    """
    Interface for components whose mutable state lives in a separate object.

    A state object is produced by :meth:`create_state` and attached to the
    module with :meth:`set_state`, so one module can be driven with
    different state objects. :meth:`on_epoch_start` and :meth:`on_epoch_end`
    are epoch hooks. Every method here raises ``NotImplementedError``;
    subclasses override the ones they use.
    """

    def __init__(self):
        pass

    def create_state(self) -> Any:
        """Create and return a fresh state object for this module."""
        raise NotImplementedError

    def set_state(self, data: Any):
        """
        Attach a state object (as returned by :meth:`create_state`).

        :param data: the state object to use from now on
        """
        raise NotImplementedError

    def on_epoch_start(self):
        """Hook for the start of an epoch."""
        raise NotImplementedError

    def on_epoch_end(self):
        """Hook for the end of an epoch."""
        raise NotImplementedError
26
27
28
class Metric(StateModule, ABC):
    """
    Base class for metrics built on :class:`StateModule`.

    Adds a :meth:`track` hook for reporting the accumulated value.
    """

    def track(self):
        """Report the current metric value; the base implementation does nothing."""
31
32
33
@dataclasses.dataclass
class AccuracyState:
    """
    Running counters used by the accuracy metrics.
    """

    # Number of samples counted so far
    samples: int = 0
    # Number of counted samples that were predicted correctly
    correct: int = 0

    def reset(self):
        """Zero both counters in place."""
        self.samples, self.correct = 0, 0
41
42
43
class Accuracy(Metric):
    """
    Classification accuracy accumulated across calls.

    Each call takes the ``argmax`` over the last dimension of ``output`` as
    the predicted class and compares it with ``target``. Positions whose
    target equals ``ignore_index`` are excluded from both the correct count
    and the sample count.
    """

    # Accumulated counters; assigned by `set_state`, not by `__init__`,
    # so `set_state` must be called before `__call__`/`track`.
    data: AccuracyState

    def __init__(self, ignore_index: int = -1):
        """
        :param ignore_index: target value marking positions to skip
        """
        super().__init__()
        self.ignore_index = ignore_index

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """
        Add one batch's correct predictions and sample count to ``self.data``.

        :param output: scores whose last dimension indexes classes; all
            leading dimensions are flattened together
        :param target: class indices, flattened to line up with the leading
            dimensions of ``output``
        """
        # `reshape` (unlike `view`) also accepts non-contiguous tensors, such
        # as transposed outputs; it still returns a view when possible.
        output = output.reshape(-1, output.shape[-1])
        target = target.reshape(-1)
        pred = output.argmax(dim=-1)
        # Positions whose target is `ignore_index` are not counted at all
        valid = target != self.ignore_index
        self.data.correct += (pred.eq(target) & valid).sum().item()
        self.data.samples += valid.sum().item()

    def create_state(self):
        """Return fresh, zeroed counters."""
        return AccuracyState()

    def set_state(self, data: Any):
        """
        Use ``data`` as the counters to accumulate into.

        :param data: an :class:`AccuracyState` (as returned by `create_state`)
        """
        self.data = data

    def on_epoch_start(self):
        """Zero the counters."""
        self.data.reset()

    def on_epoch_end(self):
        """Report the accumulated accuracy."""
        self.track()

    def track(self):
        """Report ``correct / samples`` to the tracker; skip if nothing was counted."""
        if self.data.samples == 0:
            return
        # NOTE(review): the trailing '.' in the key looks like a labml tracker
        # naming convention — confirm before changing it.
        tracker.add("accuracy.", self.data.correct / self.data.samples)
76
77
78
class AccuracyDirect(Accuracy):
    """
    Accuracy for outputs that already contain predicted labels.

    Unlike :class:`Accuracy`, no ``argmax`` is taken: ``output`` and
    ``target`` are flattened and compared element-wise.
    """

    # Accumulated counters; assigned by the inherited `set_state`
    data: AccuracyState

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """
        Add one batch's matches and element count to ``self.data``.

        :param output: predicted labels, any shape (flattened)
        :param target: true labels with the same number of elements (flattened)
        """
        # NOTE(review): the inherited `ignore_index` is not applied here, so
        # every element counts as a sample — confirm this is intended.
        # `reshape` (unlike `view`) also accepts non-contiguous tensors.
        output = output.reshape(-1)
        target = target.reshape(-1)
        self.data.correct += output.eq(target).sum().item()
        self.data.samples += target.numel()
86
87