# Path: blob/master/labml_nn/helpers/metrics.py
# 4928 views
import dataclasses
from abc import ABC
from typing import Any

import torch
from labml import tracker


class StateModule:
    """
    Base class for objects that accumulate state held outside themselves.

    A subclass builds a fresh state object with `create_state`, receives it
    through `set_state`, and exposes `on_epoch_start` / `on_epoch_end` hooks.
    Every hook raises `NotImplementedError` here. Who calls these hooks, and
    when, is decided outside this file.
    """

    def __init__(self):
        pass

    def create_state(self) -> Any:
        """Return a new, empty state object for this module."""
        raise NotImplementedError

    def set_state(self, data: Any):
        """Attach `data` (presumably produced by `create_state`) as the module's state."""
        raise NotImplementedError

    def on_epoch_start(self):
        """Hook intended to run when an epoch begins."""
        raise NotImplementedError

    def on_epoch_end(self):
        """Hook intended to run when an epoch ends."""
        raise NotImplementedError


class Metric(StateModule, ABC):
    """A `StateModule` that can report its current value via `track`."""

    def track(self):
        """Report the current metric value; does nothing by default."""
        pass


@dataclasses.dataclass
class AccuracyState:
    """Running counters for accuracy: total counted samples and correct ones."""

    samples: int = 0
    correct: int = 0

    def reset(self):
        """Zero both counters."""
        self.samples = 0
        self.correct = 0


class Accuracy(Metric):
    """
    Running classification accuracy computed from class scores.

    Elements whose target equals `ignore_index` are excluded from both the
    number of correct predictions and the number of samples.
    """

    data: AccuracyState

    def __init__(self, ignore_index: int = -1):
        """
        :param ignore_index: target value marking elements to skip.
        """
        super().__init__()
        self.ignore_index = ignore_index

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """
        Add one batch to the running counters in `self.data`.

        :param output: class scores with the classes on the last dimension.
            All leading dimensions are flattened, so their total size must
            equal the number of elements in `target`.
        :param target: class indices; entries equal to `ignore_index` are skipped.
        """
        # `reshape` also accepts non-contiguous tensors (e.g. a transposed
        # `output`), where `view` would raise.
        output = output.reshape(-1, output.shape[-1])
        target = target.reshape(-1)
        pred = output.argmax(dim=-1)
        # Count only non-ignored positions, both in the numerator and the
        # denominator.
        valid = target != self.ignore_index
        self.data.correct += (pred.eq(target) & valid).sum().item()
        self.data.samples += valid.sum().item()

    def create_state(self):
        """Return a fresh, zeroed `AccuracyState`."""
        return AccuracyState()

    def set_state(self, data: Any):
        """Use `data` as the counters this metric accumulates into."""
        self.data = data

    def on_epoch_start(self):
        """Zero the counters."""
        self.data.reset()

    def on_epoch_end(self):
        """Report the accumulated accuracy."""
        self.track()

    def track(self):
        """Send `correct / samples` to the labml tracker; skip when no samples were counted."""
        if self.data.samples == 0:
            return
        # NOTE(review): the trailing '.' in the key is kept unchanged; labml
        # appears to treat such keys specially (e.g. adding a train/valid
        # prefix) -- confirm against the labml tracker docs before renaming.
        tracker.add("accuracy.", self.data.correct / self.data.samples)


class AccuracyDirect(Accuracy):
    """
    Accuracy when `output` already contains predicted class indices.

    NOTE(review): unlike `Accuracy`, this does not apply `ignore_index`;
    every element is counted. Confirm that this is intended.
    """

    data: AccuracyState

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        """
        Add one batch to the running counters in `self.data`.

        :param output: predicted class indices, any shape with the same
            number of elements as `target`.
        :param target: class indices.
        """
        output = output.reshape(-1)
        target = target.reshape(-1)
        self.data.correct += output.eq(target).sum().item()
        self.data.samples += len(target)