# Customizing what happens in `fit()` with PyTorch

**Author:** fchollet<br>
**Date created:** 2023/06/27<br>
**Last modified:** 2024/08/01<br>
**Description:** Overriding the training step of the `Model` class with PyTorch.
## Introduction
When you're doing supervised learning, you can use `fit()` and everything works smoothly.

When you need to take control of every little detail, you can write your own training loop entirely from scratch.

But what if you need a custom training algorithm, but you still want to benefit from the convenient features of `fit()`, such as callbacks, built-in distribution support, or step fusing?
A core principle of Keras is progressive disclosure of complexity. You should always be able to get into lower-level workflows in a gradual way. You shouldn't fall off a cliff if the high-level functionality doesn't exactly match your use case. You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.
When you need to customize what `fit()` does, you should override the training step function of the `Model` class. This is the function that is called by `fit()` for every batch of data. You will then be able to call `fit()` as usual -- and it will be running your own learning algorithm.

Note that this pattern does not prevent you from building models with the Functional API. You can do this whether you're building `Sequential` models, Functional API models, or subclassed models.
Let's see how that works.
## Setup
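A minimal setup sketch: we select the torch backend via the `KERAS_BACKEND` environment variable (it must be set before `keras` is imported) and pull in the libraries used throughout this guide.

```python
import os

# The torch backend must be selected before keras is imported.
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras
from keras import layers
import numpy as np
```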
## A first simple example
Let's start from a simple example:
- We create a new class that subclasses `keras.Model`.
- We just override the method `train_step(self, data)`.
- We return a dictionary mapping metric names (including the loss) to their current value.
The input argument `data` is what gets passed to `fit` as training data:

- If you pass NumPy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple `(x, y)`.
- If you pass a `torch.utils.data.DataLoader` or a `tf.data.Dataset`, by calling `fit(dataset, ...)`, then `data` will be what gets yielded by `dataset` at each batch.
In the body of the `train_step()` method, we implement a regular training update, similar to what you are already familiar with. Importantly, we compute the loss via `self.compute_loss()`, which wraps the loss function(s) that were passed to `compile()`.

Similarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`, to update the state of the metrics that were passed in `compile()`, and we query results from `self.metrics` at the end to retrieve their current value.
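Here is a sketch of what such a `train_step()` override can look like with the torch backend (the exact listing in the published guide may differ in details): zero out stale gradients, run the forward pass, call `backward()` on the loss, then apply the gradients with the optimizer configured in `compile()`.

```python
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on what you pass to `fit()`.
        x, y = data

        # Clear any leftover gradients from the previous step.
        self.zero_grad()

        # Forward pass and loss computation. `compute_loss()` wraps the
        # loss function(s) passed to `compile()`.
        y_pred = self(x, training=True)
        loss = self.compute_loss(y=y, y_pred=y_pred)

        # Backward pass: compute gradients of the loss w.r.t. the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update the weights with the compiled optimizer.
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update the metrics configured in `compile()`, including the
        # metric that tracks the loss.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # Return a dict mapping metric names to their current value.
        return {m.name: m.result() for m in self.metrics}
```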
Let's try this out:
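A minimal usage sketch, assuming the `CustomModel` defined above and some random NumPy data:

```python
# Construct and compile an instance of CustomModel.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit()` as usual -- it will run our custom train_step().
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
```

Since we compiled with `loss="mse"` and `metrics=["mae"]`, the progress bar reports both values, computed by our own `train_step()`.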
## An end-to-end GAN example

Here's a feature-complete GAN class, overriding `compile()` to use its own signature, and implementing the entire GAN algorithm in 17 lines in `train_step()`:
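The original listing isn't reproduced here; below is a reconstruction sketch of such a GAN class. It assumes a `discriminator` model, a `generator` model, and a `latent_dim` size are passed in at construction time (illustrative definitions appear in the test-drive sketch further down). `compile()` is overridden to accept two optimizers and a loss function, and `train_step()` alternates a discriminator update and a generator update using the same zero-grad / backward / apply recipe as above.

```python
class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.built = True  # the sub-models are already built

    @property
    def metrics(self):
        return [self.d_loss_tracker, self.g_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(real_images, (tuple, list)):
            real_images = real_images[0]

        # Sample random points in the latent space and decode them to fake images.
        batch_size = real_images.shape[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images and assemble discriminator labels,
        # adding a bit of label noise (an important trick).
        real_images = torch.as_tensor(real_images, device=device)
        combined_images = torch.cat([generated_images, real_images], dim=0)
        labels = torch.cat(
            [
                torch.ones((batch_size, 1), device=device),
                torch.zeros((batch_size, 1), device=device),
            ],
            dim=0,
        )
        labels += 0.05 * keras.random.uniform(labels.shape, seed=self.seed_generator)

        # Train the discriminator.
        self.zero_grad()
        predictions = self.discriminator(combined_images)
        d_loss = self.loss_fn(labels, predictions)
        d_loss.backward()
        grads = [v.value.grad for v in self.discriminator.trainable_weights]
        with torch.no_grad():
            self.d_optimizer.apply(grads, self.discriminator.trainable_weights)

        # Sample new latent points and label them "all real" to train the
        # generator to fool the discriminator (the discriminator weights
        # are not updated here).
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        misleading_labels = torch.zeros((batch_size, 1), device=device)

        self.zero_grad()
        predictions = self.discriminator(self.generator(random_latent_vectors))
        g_loss = self.loss_fn(misleading_labels, predictions)
        g_loss.backward()
        grads = [v.value.grad for v in self.generator.trainable_weights]
        with torch.no_grad():
            self.g_optimizer.apply(grads, self.generator.trainable_weights)

        # Update the loss trackers and return their current values.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }
```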
Let's test-drive it:
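A test-drive sketch, assuming the `GAN` class above. The discriminator and generator architectures and the hyperparameters here are illustrative placeholders (a small convnet over 28x28 MNIST digits and a matching deconvolutional generator), not necessarily the ones used in the published guide:

```python
# Illustrative discriminator: maps a 28x28x1 image to a single logit.
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Illustrative generator: maps a latent vector to a 28x28x1 image.
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

# Prepare the MNIST digits as a torch DataLoader.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test]).astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(all_digits), torch.from_numpy(all_digits)
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Compile the GAN with one optimizer per sub-model and a loss on logits,
# then just call fit() -- it runs the custom train_step() above.
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)
gan.fit(dataloader, epochs=1)
```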