Path: blob/master/guides/custom_train_step_in_jax.py
3273 views
"""1Title: Customizing what happens in `fit()` with JAX2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2023/06/274Last modified: 2023/06/275Description: Overriding the training step of the Model class with JAX.6Accelerator: GPU7"""89"""10## Introduction1112When you're doing supervised learning, you can use `fit()` and everything works13smoothly.1415When you need to take control of every little detail, you can write your own training16loop entirely from scratch.1718But what if you need a custom training algorithm, but you still want to benefit from19the convenient features of `fit()`, such as callbacks, built-in distribution support,20or step fusing?2122A core principle of Keras is **progressive disclosure of complexity**. You should23always be able to get into lower-level workflows in a gradual way. You shouldn't fall24off a cliff if the high-level functionality doesn't exactly match your use case. You25should be able to gain more control over the small details while retaining a26commensurate amount of high-level convenience.2728When you need to customize what `fit()` does, you should **override the training step29function of the `Model` class**. This is the function that is called by `fit()` for30every batch of data. You will then be able to call `fit()` as usual -- and it will be31running your own learning algorithm.3233Note that this pattern does not prevent you from building models with the Functional34API. You can do this whether you're building `Sequential` models, Functional API35models, or subclassed models.3637Let's see how that works.38"""3940"""41## Setup42"""4344import os4546# This guide can only be run with the JAX backend.47os.environ["KERAS_BACKEND"] = "jax"4849import jax50import keras51import numpy as np5253"""54## A first simple example5556Let's start from a simple example:5758- We create a new class that subclasses `keras.Model`.59- We implement a fully-stateless `compute_loss_and_updates()` method60to compute the loss as well as the updated values for the non-trainable61variables of the model. Internally, it calls `stateless_call()` and62the built-in `compute_loss()`.63- We implement a fully-stateless `train_step()` method to compute current64metric values (including the loss) as well as updated values for the65trainable variables, the optimizer variables, and the metric variables.6667Note that you can also take into account the `sample_weight` argument by:6869- Unpacking the data as `x, y, sample_weight = data`70- Passing `sample_weight` to `compute_loss()`71- Passing `sample_weight` alongside `y` and `y_pred`72to metrics in `stateless_update_state()`73"""747576class CustomModel(keras.Model):77def compute_loss_and_updates(78self,79trainable_variables,80non_trainable_variables,81x,82y,83training=False,84):85y_pred, non_trainable_variables = self.stateless_call(86trainable_variables,87non_trainable_variables,88x,89training=training,90)91loss = self.compute_loss(x, y, y_pred)92return loss, (y_pred, non_trainable_variables)9394def train_step(self, state, data):95(96trainable_variables,97non_trainable_variables,98optimizer_variables,99metrics_variables,100) = state101x, y = data102103# Get the gradient function.104grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)105106# Compute the gradients.107(loss, (y_pred, non_trainable_variables)), grads = grad_fn(108trainable_variables,109non_trainable_variables,110x,111y,112training=True,113)114115# Update trainable variables and optimizer variables.116(117trainable_variables,118optimizer_variables,119) = self.optimizer.stateless_apply(120optimizer_variables, grads, trainable_variables121)122123# Update metrics.124new_metrics_vars = []125logs = {}126for metric in self.metrics:127this_metric_vars = metrics_variables[128len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)129]130if metric.name == "loss":131this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)132else:133this_metric_vars = metric.stateless_update_state(134this_metric_vars, y, y_pred135)136logs[metric.name] = metric.stateless_result(this_metric_vars)137new_metrics_vars += this_metric_vars138139# Return metric logs and updated state variables.140state = (141trainable_variables,142non_trainable_variables,143optimizer_variables,144new_metrics_vars,145)146return logs, state147148149"""150Let's try this out:151"""152153# Construct and compile an instance of CustomModel154inputs = keras.Input(shape=(32,))155outputs = keras.layers.Dense(1)(inputs)156model = CustomModel(inputs, outputs)157model.compile(optimizer="adam", loss="mse", metrics=["mae"])158159# Just use `fit` as usual160x = np.random.random((1000, 32))161y = np.random.random((1000, 1))162model.fit(x, y, epochs=3)163164165"""166## Going lower-level167168Naturally, you could just skip passing a loss function in `compile()`, and instead do169everything *manually* in `train_step`. Likewise for metrics.170171Here's a lower-level example, that only uses `compile()` to configure the optimizer:172"""173174175class CustomModel(keras.Model):176def __init__(self, *args, **kwargs):177super().__init__(*args, **kwargs)178self.loss_tracker = keras.metrics.Mean(name="loss")179self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")180self.loss_fn = keras.losses.MeanSquaredError()181182def compute_loss_and_updates(183self,184trainable_variables,185non_trainable_variables,186x,187y,188training=False,189):190y_pred, non_trainable_variables = self.stateless_call(191trainable_variables,192non_trainable_variables,193x,194training=training,195)196loss = self.loss_fn(y, y_pred)197return loss, (y_pred, non_trainable_variables)198199def train_step(self, state, data):200(201trainable_variables,202non_trainable_variables,203optimizer_variables,204metrics_variables,205) = state206x, y = data207208# Get the gradient function.209grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)210211# Compute the gradients.212(loss, (y_pred, non_trainable_variables)), grads = grad_fn(213trainable_variables,214non_trainable_variables,215x,216y,217training=True,218)219220# Update trainable variables and optimizer variables.221(222trainable_variables,223optimizer_variables,224) = self.optimizer.stateless_apply(225optimizer_variables, grads, trainable_variables226)227228# Update metrics.229loss_tracker_vars = metrics_variables[: len(self.loss_tracker.variables)]230mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]231232loss_tracker_vars = self.loss_tracker.stateless_update_state(233loss_tracker_vars, loss234)235mae_metric_vars = self.mae_metric.stateless_update_state(236mae_metric_vars, y, y_pred237)238239logs = {}240logs[self.loss_tracker.name] = self.loss_tracker.stateless_result(241loss_tracker_vars242)243logs[self.mae_metric.name] = self.mae_metric.stateless_result(mae_metric_vars)244245new_metrics_vars = loss_tracker_vars + mae_metric_vars246247# Return metric logs and updated state variables.248state = (249trainable_variables,250non_trainable_variables,251optimizer_variables,252new_metrics_vars,253)254return logs, state255256@property257def metrics(self):258# We list our `Metric` objects here so that `reset_states()` can be259# called automatically at the start of each epoch260# or at the start of `evaluate()`.261return [self.loss_tracker, self.mae_metric]262263264# Construct an instance of CustomModel265inputs = keras.Input(shape=(32,))266outputs = keras.layers.Dense(1)(inputs)267model = CustomModel(inputs, outputs)268269# We don't pass a loss or metrics here.270model.compile(optimizer="adam")271272# Just use `fit` as usual -- you can use callbacks, etc.273x = np.random.random((1000, 32))274y = np.random.random((1000, 1))275model.fit(x, y, epochs=5)276277278"""279## Providing your own evaluation step280281What if you want to do the same for calls to `model.evaluate()`? Then you would282override `test_step` in exactly the same way. Here's what it looks like:283"""284285286class CustomModel(keras.Model):287def test_step(self, state, data):288# Unpack the data.289x, y = data290(291trainable_variables,292non_trainable_variables,293metrics_variables,294) = state295296# Compute predictions and loss.297y_pred, non_trainable_variables = self.stateless_call(298trainable_variables,299non_trainable_variables,300x,301training=False,302)303loss = self.compute_loss(x, y, y_pred)304305# Update metrics.306new_metrics_vars = []307for metric in self.metrics:308this_metric_vars = metrics_variables[309len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)310]311if metric.name == "loss":312this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)313else:314this_metric_vars = metric.stateless_update_state(315this_metric_vars, y, y_pred316)317logs = metric.stateless_result(this_metric_vars)318new_metrics_vars += this_metric_vars319320# Return metric logs and updated state variables.321state = (322trainable_variables,323non_trainable_variables,324new_metrics_vars,325)326return logs, state327328329# Construct an instance of CustomModel330inputs = keras.Input(shape=(32,))331outputs = keras.layers.Dense(1)(inputs)332model = CustomModel(inputs, outputs)333model.compile(loss="mse", metrics=["mae"])334335# Evaluate with our custom test_step336x = np.random.random((1000, 32))337y = np.random.random((1000, 1))338model.evaluate(x, y)339340341"""342That's it!343"""344345346