Metrics
A metric is a function that is used to judge the performance of your model.
Metric functions are similar to loss functions, except that the results from evaluating a metric are not used when training the model. Note that you may use any loss function as a metric.
Available metrics
{{toc}}
Usage with compile() & fit()
The compile() method takes a metrics argument, which is a list of metrics:
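For instance, here is a minimal sketch; the small binary classifier below is an illustrative assumption, and any compiled model works the same way:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# A toy binary classifier, used only to illustrate the compile() arguments.
model = keras.Sequential([
    keras.Input(shape=(10,)),
    layers.Dense(64, activation='relu'),
    layers.Dense(1, activation='sigmoid'),
])

model.compile(
    optimizer='adam',
    loss='mean_squared_error',
    metrics=[
        keras.metrics.MeanSquaredError(),
        keras.metrics.AUC(),
    ],
)
```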
Metric values are displayed during fit() and logged to the History object returned by fit(). They are also returned by model.evaluate().
Note that the best way to monitor your metrics during training is via TensorBoard.
To track metrics under a specific name, you can pass the name argument to the metric constructor:
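For example, reusing the model from the snippet above (the names 'my_mse' and 'my_auc' are arbitrary):

```python
model.compile(
    optimizer='adam',
    loss='mean_squared_error',
    metrics=[
        keras.metrics.MeanSquaredError(name='my_mse'),
        keras.metrics.AUC(name='my_auc'),
    ],
)
```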
All built-in metrics may also be passed via their string identifier (in this case, default constructor argument values are used, including a default metric name):
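For instance, with the same model as above:

```python
model.compile(
    optimizer='adam',
    loss='mean_squared_error',
    metrics=['MeanSquaredError', 'AUC'],
)
```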
Standalone usage
Unlike losses, metrics are stateful. You update their state using the update_state() method, and you query the scalar metric result using the result() method:
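Here is a sketch using an AUC metric instance on its own; the inputs are arbitrary toy values:

```python
from tensorflow import keras

m = keras.metrics.AUC()
m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
print('Intermediate result:', float(m.result()))

m.update_state([1, 1, 1, 1], [0, 1, 1, 0])
print('Final result:', float(m.result()))
```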
The internal state can be cleared via metric.reset_states().
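Continuing with the AUC instance above:

```python
m.reset_states()
print('Result after reset:', float(m.result()))  # Expected to be 0.0 for an empty state.
```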
Here's how you would use a metric as part of a simple custom training loop:
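Below is a sketch of such a loop; `model` and `dataset` are assumed to be defined elsewhere (e.g. a classifier producing logits and a tf.data.Dataset yielding (features, one-hot labels) batches):

```python
import tensorflow as tf
from tensorflow import keras

accuracy = keras.metrics.CategoricalAccuracy()
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

for step, (x, y) in enumerate(dataset):
    # Forward pass and loss computation under a gradient tape.
    with tf.GradientTape() as tape:
        logits = model(x)
        loss_value = loss_fn(y, logits)
    gradients = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

    # Accumulate the metric's state with the current batch.
    accuracy.update_state(y, logits)

    # Periodically report the value accumulated so far.
    if step % 100 == 0:
        print('Step:', step)
        print('Total running accuracy so far: %.3f' % accuracy.result())
```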
Creating custom metrics
As simple callables (stateless)
Much like loss functions, any callable with signature metric_fn(y_true, y_pred) that returns an array of losses (one per sample in the input batch) can be passed to compile() as a metric. Note that sample weighting is automatically supported for any such metric.
Here's a simple example:
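This is a sketch of a stateless metric (a per-sample mean squared error), compiled into the model from earlier:

```python
import tensorflow as tf

def my_metric_fn(y_true, y_pred):
    # Returns one value per sample; Keras averages these over each batch.
    squared_difference = tf.square(y_true - y_pred)
    return tf.reduce_mean(squared_difference, axis=-1)

model.compile(optimizer='adam', loss='mean_squared_error', metrics=[my_metric_fn])
```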
In this case, the scalar metric value you are tracking during training and evaluation is the average of the per-batch metric values for all batches seen during a given epoch (or during a given call to model.evaluate()).
As subclasses of Metric (stateful)
Not all metrics can be expressed via stateless callables: metrics are evaluated for each batch during training and evaluation, and in some cases the average of the per-batch values is not what you are interested in.
Let's say that you want to compute AUC over a given evaluation dataset: the average of the per-batch AUC values isn't the same as the AUC over the entire dataset.
For such metrics, you're going to want to subclass the Metric class, which can maintain a state across batches. It's easy:

- Create the state variables in __init__
- Update the variables given y_true and y_pred in update_state()
- Return the scalar metric result in result()
- Clear the state in reset_states()
Here's a simple example computing binary true positives:
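Here is a sketch of such a subclass, together with standalone usage (the toy inputs are arbitrary):

```python
import tensorflow as tf
from tensorflow import keras


class BinaryTruePositives(keras.metrics.Metric):

    def __init__(self, name='binary_true_positives', **kwargs):
        super().__init__(name=name, **kwargs)
        # State variable accumulating the count of true positives across batches.
        self.true_positives = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        y_pred = tf.cast(y_pred, tf.bool)

        # A sample counts as a true positive when both label and prediction are 1.
        values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
        values = tf.cast(values, self.dtype)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
        return self.true_positives

    def reset_states(self):
        self.true_positives.assign(0.0)


m = BinaryTruePositives()
m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
print('Intermediate result:', float(m.result()))  # 1 true positive so far

m.update_state([1, 1, 1, 1], [0, 1, 1, 0])
print('Final result:', float(m.result()))  # 3 true positives in total
```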