Customizing what happens in fit() with JAX
Author: fchollet
Date created: 2023/06/27
Last modified: 2023/06/27
Description: Overriding the training step of the Model class with JAX.
Introduction
When you're doing supervised learning, you can use fit() and everything works smoothly.

When you need to take control of every little detail, you can write your own training loop entirely from scratch.

But what if you need a custom training algorithm, but you still want to benefit from the convenient features of fit(), such as callbacks, built-in distribution support, or step fusing?
A core principle of Keras is progressive disclosure of complexity. You should always be able to get into lower-level workflows in a gradual way. You shouldn't fall off a cliff if the high-level functionality doesn't exactly match your use case. You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.
When you need to customize what fit() does, you should override the training step function of the Model class. This is the function that is called by fit() for every batch of data. You will then be able to call fit() as usual -- and it will be running your own learning algorithm.
Note that this pattern does not prevent you from building models with the Functional API. You can do this whether you're building Sequential models, Functional API models, or subclassed models.
Let's see how that works.
Setup
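The setup cell isn't included in this text; a minimal version would select the JAX backend before Keras is imported for the first time and pull in the libraries used in the sketches below (numpy is only needed for the demo data):

```python
import os

# The custom train_step below relies on JAX, so select the JAX backend
# before Keras is imported for the first time.
os.environ["KERAS_BACKEND"] = "jax"

import jax
import numpy as np
import keras
```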
A first simple example
Let's start from a simple example:
- We create a new class that subclasses keras.Model.
- We implement a fully-stateless compute_loss_and_updates() method to compute the loss as well as the updated values for the non-trainable variables of the model. Internally, it calls stateless_call() and the built-in compute_loss().
- We implement a fully-stateless train_step() method to compute current metric values (including the loss) as well as updated values for the trainable variables, the optimizer variables, and the metric variables (see the sketch after this list).
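The notebook's code cell isn't reproduced in this text, so here is a minimal sketch of what these two methods could look like, relying on Keras 3's stateless APIs (stateless_call(), compute_loss(), optimizer.stateless_apply(), metric.stateless_update_state()); the class name CustomModel and the exact shape of the metric-update loop are our own reconstruction, not necessarily the author's code:

```python
class CustomModel(keras.Model):
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        training=False,
    ):
        # Run the forward pass statelessly, then compute the loss.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss = self.compute_loss(x, y, y_pred)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data

        # Get the gradient function via JAX autodiff.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # Compute the loss, the auxiliary outputs, and the gradients.
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            training=True,
        )

        # Update trainable variables and optimizer variables.
        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update metrics statelessly, slicing each metric's variables out of
        # the flat metrics_variables list in order.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred
                )
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and the updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state
```

Note that train_step() receives and returns explicit state tuples rather than mutating variables in place, which is what makes it compatible with JAX's stateless, functional execution model.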
Note that you can also take into account the sample_weight argument by:

- Unpacking the data as x, y, sample_weight = data
- Passing sample_weight to compute_loss() (sketched below)
- Passing sample_weight alongside y and y_pred to metrics in stateless_update_state()
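As an illustration of the first two points, a sample_weight-aware loss computation might look roughly like the sketch below; the class name WeightedCustomModel is ours, and the corresponding train_step() changes are only summarized in comments:

```python
class WeightedCustomModel(keras.Model):
    # Only the loss computation is shown. train_step() would also unpack
    # x, y, sample_weight = data, forward sample_weight to this method,
    # and pass it alongside y and y_pred to each metric's
    # stateless_update_state().
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        sample_weight,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables, non_trainable_variables, x, training=training
        )
        # Pass sample_weight through to the built-in compute_loss().
        loss = self.compute_loss(x, y, y_pred, sample_weight=sample_weight)
        return loss, (y_pred, non_trainable_variables)
```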
Let's try this out:
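The demo cell isn't reproduced here; one way to exercise the basic CustomModel sketched above could be the following, where the input shape and the random data are arbitrary:

```python
# Construct and compile an instance of CustomModel.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use fit() as usual -- it will run the custom train_step above.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
```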
Going lower-level
Naturally, you could just skip passing a loss function in compile(), and instead do everything manually in train_step. Likewise for metrics.
Here's a lower-level example that only uses compile() to configure the optimizer:
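Again, the original code cell isn't included in this text; a sketch of this lower-level pattern might look like the following, where the Mean loss tracker, the MeanAbsoluteError metric, and the class name LowLevelCustomModel are our own illustrative choices:

```python
class LowLevelCustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Metrics and loss are created here and handled manually in train_step().
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def compute_loss_and_updates(
        self, trainable_variables, non_trainable_variables, x, y, training=False
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables, non_trainable_variables, x, training=training
        )
        # Compute the loss with our own loss function instead of compute_loss().
        loss = self.loss_fn(y, y_pred)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data

        # Compute loss and gradients.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables, non_trainable_variables, x, y, training=True
        )

        # Apply the optimizer update.
        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update our own metrics statelessly, slicing their variables out of state.
        loss_tracker_vars = metrics_variables[: len(self.loss_tracker.variables)]
        mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]

        loss_tracker_vars = self.loss_tracker.stateless_update_state(
            loss_tracker_vars, loss
        )
        mae_metric_vars = self.mae_metric.stateless_update_state(
            mae_metric_vars, y, y_pred
        )

        logs = {
            self.loss_tracker.name: self.loss_tracker.stateless_result(
                loss_tracker_vars
            ),
            self.mae_metric.name: self.mae_metric.stateless_result(mae_metric_vars),
        }

        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            loss_tracker_vars + mae_metric_vars,
        )
        return logs, state

    @property
    def metrics(self):
        # Listing the metrics here lets fit() reset them at the start of each epoch.
        return [self.loss_tracker, self.mae_metric]


# compile() is only given an optimizer; loss and metrics live in the class.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = LowLevelCustomModel(inputs, outputs)
model.compile(optimizer="adam")
model.fit(np.random.random((1000, 32)), np.random.random((1000, 1)), epochs=3)
```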
Providing your own evaluation step
What if you want to do the same for calls to model.evaluate()? Then you would override test_step in exactly the same way. Here's what it looks like:
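The corresponding code cell isn't reproduced in this text; a sketch of an evaluation-only override might look like this, where the class name EvalCustomModel is ours and only test_step() is overridden (note that the evaluation state tuple carries no optimizer variables):

```python
class EvalCustomModel(keras.Model):
    def test_step(self, state, data):
        # Unpack the data and the state.
        x, y = data
        (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = state

        # Compute predictions and loss; no gradients are needed here.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables, non_trainable_variables, x, training=False
        )
        loss = self.compute_loss(x, y, y_pred)

        # Update the metrics statelessly.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred
                )
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and the updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            new_metrics_vars,
        )
        return logs, state


# Construct, compile, and evaluate as usual.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = EvalCustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])
model.evaluate(np.random.random((1000, 32)), np.random.random((1000, 1)))
```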
That's it!