Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/callbacks/loss_metrics.py
781 views
1
from ..torch_core import *
2
from ..callback import *
3
from ..basic_train import Learner, LearnerCallback
4
5
__all__ = ['LossMetrics']


class LossMetrics(LearnerCallback):
    "Add `loss_func.metrics` to metrics named by `loss_func.metric_names`"
    _order = -20  # Needs to run before the recorder

    def on_train_begin(self, **kwargs):
        "Add the metrics names to the `Recorder`."
        # `getattr` instead of plain attribute access: a loss function that does not
        # define `metric_names` at all should reach the warning below rather than
        # crash with AttributeError. A `None` value is also mapped to `[]` by `ifnone`.
        self.names = ifnone(getattr(self.learn.loss_func, 'metric_names', None), [])
        if not self.names: warn('LossMetrics requested but no loss_func.metric_names provided')
        self.learn.recorder.add_metric_names(self.names)

    def on_epoch_begin(self, **kwargs):
        "Initialize the metrics for this epoch."
        # Per-name running sums (each batch value weighted by its batch size)
        # and the total number of validation samples seen this epoch.
        self.metrics = {name: 0. for name in self.names}
        self.nums = 0

    def on_batch_end(self, last_target, train, **kwargs):
        "Update the metrics if not `train`"
        if train: return
        # Multi-target batches may pass the target as a list/tuple of tensors;
        # the batch size is then the leading dimension of the first element.
        first_target = last_target[0] if isinstance(last_target, (list, tuple)) else last_target
        bs = first_target.size(0)
        for name in self.names:
            # Weight by batch size so that dividing by `self.nums` at epoch end gives a
            # per-sample mean (assumes each value is a mean over its batch — confirm
            # against the loss function that populates `metrics`).
            self.metrics[name] += bs * self.learn.loss_func.metrics[name].detach().cpu()
        self.nums += bs

    def on_epoch_end(self, last_metrics, **kwargs):
        "Finish the computation and send the result to the Recorder."
        # No validation batches were seen: leave `last_metrics` untouched.
        if not self.nums: return
        metrics = [self.metrics[name] / self.nums for name in self.names]
        return {'last_metrics': last_metrics + metrics}
35
36