Path: blob/master/guides/writing_a_custom_training_loop_in_jax.py
"""1Title: Writing a training loop from scratch in JAX2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2023/06/254Last modified: 2023/06/255Description: Writing low-level training & evaluation loops in JAX.6Accelerator: None7"""89"""10## Setup11"""1213import os1415# This guide can only be run with the jax backend.16os.environ["KERAS_BACKEND"] = "jax"1718import jax1920# We import TF so we can use tf.data.21import tensorflow as tf22import keras23import numpy as np2425"""26## Introduction2728Keras provides default training and evaluation loops, `fit()` and `evaluate()`.29Their usage is covered in the guide30[Training & evaluation with the built-in methods](/guides/training_with_built_in_methods/).3132If you want to customize the learning algorithm of your model while still leveraging33the convenience of `fit()`34(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and35implement your own `train_step()` method, which36is called repeatedly during `fit()`.3738Now, if you want very low-level control over training & evaluation, you should write39your own training & evaluation loops from scratch. This is what this guide is about.40"""4142"""43## A first end-to-end example4445To write a custom training loop, we need the following ingredients:4647- A model to train, of course.48- An optimizer. You could either use an optimizer from `keras.optimizers`, or49one from the `optax` package.50- A loss function.51- A dataset. The standard in the JAX ecosystem is to load data via `tf.data`,52so that's what we'll use.5354Let's line them up.5556First, let's get the model and the MNIST dataset:57"""585960def get_model():61inputs = keras.Input(shape=(784,), name="digits")62x1 = keras.layers.Dense(64, activation="relu")(inputs)63x2 = keras.layers.Dense(64, activation="relu")(x1)64outputs = keras.layers.Dense(10, name="predictions")(x2)65model = keras.Model(inputs=inputs, outputs=outputs)66return model676869model = get_model()7071# Prepare the training dataset.72batch_size = 3273(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()74x_train = np.reshape(x_train, (-1, 784)).astype("float32")75x_test = np.reshape(x_test, (-1, 784)).astype("float32")76y_train = keras.utils.to_categorical(y_train)77y_test = keras.utils.to_categorical(y_test)7879# Reserve 10,000 samples for validation.80x_val = x_train[-10000:]81y_val = y_train[-10000:]82x_train = x_train[:-10000]83y_train = y_train[:-10000]8485# Prepare the training dataset.86train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))87train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)8889# Prepare the validation dataset.90val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))91val_dataset = val_dataset.batch(batch_size)9293"""94Next, here's the loss function and the optimizer.95We'll use a Keras optimizer in this case.96"""9798# Instantiate a loss function.99loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)100101# Instantiate an optimizer.102optimizer = keras.optimizers.Adam(learning_rate=1e-3)103104"""105### Getting gradients in JAX106107Let's train our model using mini-batch gradient with a custom training loop.108109In JAX, gradients are computed via *metaprogramming*: you call the `jax.grad` (or110`jax.value_and_grad` on a function in order to create a gradient-computing function111for that first function.112113So the first thing we need is a function that returns the loss value.114That's the function we'll use to generate the gradient function. 

"""
### Getting gradients in JAX

Let's train our model using mini-batch gradient descent with a custom training loop.

In JAX, gradients are computed via *metaprogramming*: you call `jax.grad` (or
`jax.value_and_grad`) on a function in order to create a gradient-computing function
for that first function.

So the first thing we need is a function that returns the loss value.
That's the function we'll use to generate the gradient function. Something like this:

```python
def compute_loss(x, y):
    ...
    return loss
```

Once you have such a function, you can compute gradients via metaprogramming like this:

```python
grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)
```

Typically, you don't just want to get the gradient values, you also want to get
the loss value. You can do this by using `jax.value_and_grad` instead of `jax.grad`:

```python
grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)
```

### JAX computation is purely stateless

In JAX, everything must be a stateless function -- so our loss computation function
must be stateless as well. That means that all Keras variables (e.g. weight tensors)
must be passed as function inputs, and any variable that has been updated during the
forward pass must be returned as a function output. The function must have no side
effects.

During the forward pass, the non-trainable variables of a Keras model might get
updated. These variables could be, for instance, RNG seed state variables or
BatchNormalization statistics. We're going to need to return those. So we need
something like this:

```python
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    ...
    return loss, non_trainable_variables
```

Once you have such a function, you can get the gradient function by
specifying `has_aux=True` in `value_and_grad`: it tells JAX that the loss
computation function returns more outputs than just the loss. Note that the loss
should always be the first output.

```python
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
    trainable_variables, non_trainable_variables, x, y
)
```
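
To make these mechanics concrete, here's a tiny self-contained toy example (with
made-up numbers, unrelated to the MNIST model in this guide) showing the
`(loss, aux), grads` structure you get back with `has_aux=True`:

```python
import jax.numpy as jnp


def toy_loss(w, x, y):
    y_pred = x * w
    loss = jnp.mean((y_pred - y) ** 2)
    return loss, y_pred  # `y_pred` plays the role of the auxiliary output


toy_grad_fn = jax.value_and_grad(toy_loss, has_aux=True)
(loss, y_pred), grad = toy_grad_fn(2.0, jnp.array([1.0, 2.0]), jnp.array([3.0, 5.0]))
# `loss` and `y_pred` come from the forward pass; `grad` is d(loss)/d(w).
```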

Now that we have established the basics,
let's implement this `compute_loss_and_updates` function.
Keras models have a `stateless_call` method which will come in handy here.
It works just like `model.__call__`, but it requires you to explicitly
pass the value of all the variables in the model, and it returns not just
the `__call__` outputs but also the (potentially updated) non-trainable
variables.
"""


def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables


"""
Let's get the gradient function:
"""

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

"""
### The training step function

Next, let's implement the end-to-end training step: the function
that will run the forward pass, compute the loss, compute the gradients,
and use the optimizer to update the trainable variables. This function
also needs to be stateless, so it will get as input a `state` tuple that
includes every state element we're going to use:

- `trainable_variables` and `non_trainable_variables`: the model's variables.
- `optimizer_variables`: the optimizer's state variables,
such as momentum accumulators.

To update the trainable variables, we use the optimizer's stateless method
`stateless_apply`. It's equivalent to `optimizer.apply()`, but it requires
always passing `trainable_variables` and `optimizer_variables`. It returns
both the updated trainable variables and the updated `optimizer_variables`.
"""


def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )


"""
### Make it fast with `jax.jit`

By default, JAX operations run eagerly,
just like in TensorFlow eager mode and PyTorch eager mode.
And just like TensorFlow eager mode and PyTorch eager mode, it's pretty slow
-- eager mode is better used as a debugging environment, not as a way to do
any actual work. So let's make our `train_step` fast by compiling it.

When you have a stateless JAX function, you can compile it to XLA via the
`@jax.jit` decorator. It will get traced during its first execution, and in
subsequent executions you will be executing the traced graph (this is just
like `@tf.function(jit_compile=True)`). Let's try it:
"""


@jax.jit
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )


"""
We're now ready to train our model. The training loop itself
is trivial: we just repeatedly call `loss, state = train_step(state, data)`.

Note:

- We convert the TF tensors yielded by the `tf.data.Dataset` to NumPy
before passing them to our JAX function.
- All variables must be built beforehand:
the model must be built and the optimizer must be built. Since we're using a
Functional API model, it's already built, but if it were a subclassed model
you'd need to call it on a batch of data to build it.
"""

# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")

"""
A key thing to notice here is that the loop is entirely stateless -- the variables
attached to the model (`model.weights`) are never updated during the loop.
Their new values are only stored in the `state` tuple. That means that at some point,
before saving the model, you should attach the new variable values back to the model.

Just call `variable.assign(new_value)` on each model variable you want to update:
"""

trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)
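
"""
With the new values attached, the `model` object is back in sync with the trained
state, so the usual stateful Keras workflows apply again. For instance, you could save
it and reload it -- a quick sketch (the filename here is just an illustration):

```python
model.save("mnist_custom_jax_loop.keras")
restored_model = keras.saving.load_model("mnist_custom_jax_loop.keras")
predictions = restored_model.predict(x_val[:8])
```
"""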

"""
## Low-level handling of metrics

Let's add metrics monitoring to this basic training loop.

You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training
loops written from scratch. Here's the flow:

- Instantiate the metric at the start of the loop.
- Include `metric_variables` in the `train_step` arguments
and `compute_loss_and_updates` arguments.
- Call `metric.stateless_update_state()` in the `compute_loss_and_updates` function.
It's equivalent to `update_state()` -- only stateless.
- When you need to display the current value of the metric, outside the `train_step`
(in the eager scope), attach the new metric variable values to the metric object
and call `metric.result()`.
- Call `metric.reset_state()` when you need to clear the state of the metric
(typically at the end of an epoch).
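
Here's a minimal standalone sketch of that flow on made-up data (an illustration only,
separate from the training loops below):

```python
acc = keras.metrics.CategoricalAccuracy()
metric_variables = acc.variables

y_true = np.array([[0.0, 1.0], [1.0, 0.0]])
y_pred = np.array([[0.1, 0.9], [0.2, 0.8]])
metric_variables = acc.stateless_update_state(metric_variables, y_true, y_pred)

# To read the current value, attach the updated variables back, then call `result()`.
for variable, value in zip(acc.variables, metric_variables):
    variable.assign(value)
print(acc.result())  # 0.5 -- one of the two samples is classified correctly.
acc.reset_state()
```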

Let's use this knowledge to compute `CategoricalAccuracy` on training and
validation data at the end of training:
"""

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()


def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)


grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)


@jax.jit
def train_step(state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    ) = state
    x, y = data
    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, metric_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    )


"""
We'll also prepare an evaluation step function:
"""


@jax.jit
def eval_step(state, data):
    trainable_variables, non_trainable_variables, metric_variables = state
    x, y = data
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = val_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        metric_variables,
    )


"""
Here are our loops:
"""

# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
)

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, _, metric_variables = state
        for variable, value in zip(train_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Training accuracy: {train_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")

# Start the eval state from the (fresh) validation metric variables,
# and drop the training metric variables from the state.
metric_variables = val_acc_metric.variables
(
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    _,
) = state
state = trainable_variables, non_trainable_variables, metric_variables

# Eval loop
for step, data in enumerate(val_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = eval_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, metric_variables = state
        for variable, value in zip(val_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Validation accuracy: {val_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")

"""
## Low-level handling of losses tracked by the model

Layers & models recursively track any losses created during the forward pass
by layers that call `self.add_loss(value)`. The resulting list of scalar loss
values is available via the property `model.losses`
at the end of the forward pass.

If you want to use these loss components, you should sum them
and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:
"""


class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * jax.numpy.sum(inputs))
        return inputs


"""
Let's build a really simple model that uses it:
"""

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

"""
Here's what our `compute_loss_and_updates` function should look like now:

- Pass `return_losses=True` to `model.stateless_call()`.
- Sum the resulting `losses` and add them to the main loss.
"""


def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables, losses = model.stateless_call(
        trainable_variables, non_trainable_variables, x, return_losses=True
    )
    loss = loss_fn(y, y_pred)
    if losses:
        loss += jax.numpy.sum(losses)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)


"""
That's it!
"""
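
"""
As a closing note, the `train_step` itself doesn't need to change for this version:
the tracked losses are already folded into the scalar `loss` that gets differentiated.
For completeness, here is a sketch of how you would wire it up, mirroring the metrics
example above (not re-run here):

```python
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)


@jax.jit
def train_step(state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    ) = state
    x, y = data
    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, metric_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    )


# Note: since this is a brand-new model, you would also rebuild the optimizer state
# (optimizer.build(model.trainable_variables)) and reset the metrics before training.
```
"""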