"""1Title: Multi-GPU distributed training with JAX2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2023/07/114Last modified: 2023/07/115Description: Guide to multi-GPU/TPU training for Keras models with JAX.6Accelerator: GPU7"""89"""10## Introduction1112There are generally two ways to distribute computation across multiple devices:1314**Data parallelism**, where a single model gets replicated on multiple devices or15multiple machines. Each of them processes different batches of data, then they merge16their results. There exist many variants of this setup, that differ in how the different17model replicas merge results, in whether they stay in sync at every batch or whether they18are more loosely coupled, etc.1920**Model parallelism**, where different parts of a single model run on different devices,21processing a single batch of data together. This works best with models that have a22naturally-parallel architecture, such as models that feature multiple branches.2324This guide focuses on data parallelism, in particular **synchronous data parallelism**,25where the different replicas of the model stay in sync after each batch they process.26Synchronicity keeps the model convergence behavior identical to what you would see for27single-device training.2829Specifically, this guide teaches you how to use `jax.sharding` APIs to train Keras30models, with minimal changes to your code, on multiple GPUs or TPUS (typically 2 to 16)31installed on a single machine (single host, multi-device training). This is the32most common setup for researchers and small-scale industry workflows.33"""3435"""36## Setup3738Let's start by defining the function that creates the model that we will train,39and the function that creates the dataset we will train on (MNIST in this case).40"""4142import os4344os.environ["KERAS_BACKEND"] = "jax"4546import jax47import numpy as np48import tensorflow as tf49import keras5051from jax.experimental import mesh_utils52from jax.sharding import Mesh53from jax.sharding import NamedSharding54from jax.sharding import PartitionSpec as P555657def get_model():58# Make a simple convnet with batch normalization and dropout.59inputs = keras.Input(shape=(28, 28, 1))60x = keras.layers.Rescaling(1.0 / 255.0)(inputs)61x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(62x63)64x = keras.layers.BatchNormalization(scale=False, center=True)(x)65x = keras.layers.ReLU()(x)66x = keras.layers.Conv2D(67filters=24,68kernel_size=6,69use_bias=False,70strides=2,71)(x)72x = keras.layers.BatchNormalization(scale=False, center=True)(x)73x = keras.layers.ReLU()(x)74x = keras.layers.Conv2D(75filters=32,76kernel_size=6,77padding="same",78strides=2,79name="large_k",80)(x)81x = keras.layers.BatchNormalization(scale=False, center=True)(x)82x = keras.layers.ReLU()(x)83x = keras.layers.GlobalAveragePooling2D()(x)84x = keras.layers.Dense(256, activation="relu")(x)85x = keras.layers.Dropout(0.5)(x)86outputs = keras.layers.Dense(10)(x)87model = keras.Model(inputs, outputs)88return model899091def get_datasets():92# Load the data and split it between train and test sets93(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()9495# Scale images to the [0, 1] range96x_train = x_train.astype("float32")97x_test = x_test.astype("float32")98# Make sure images have shape (28, 28, 1)99x_train = np.expand_dims(x_train, -1)100x_test = np.expand_dims(x_test, -1)101print("x_train shape:", x_train.shape)102print(x_train.shape[0], "train samples")103print(x_test.shape[0], "test 
samples")104105# Create TF Datasets106train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))107eval_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))108return train_data, eval_data109110111"""112## Single-host, multi-device synchronous training113114In this setup, you have one machine with several GPUs or TPUs on it (typically 2 to 16).115Each device will run a copy of your model (called a **replica**). For simplicity, in116what follows, we'll assume we're dealing with 8 GPUs, at no loss of generality.117118**How it works**119120At each step of training:121122- The current batch of data (called **global batch**) is split into 8 different123sub-batches (called **local batches**). For instance, if the global batch has 512124samples, each of the 8 local batches will have 64 samples.125- Each of the 8 replicas independently processes a local batch: they run a forward pass,126then a backward pass, outputting the gradient of the weights with respect to the loss of127the model on the local batch.128- The weight updates originating from local gradients are efficiently merged across the 8129replicas. Because this is done at the end of every step, the replicas always stay in130sync.131132In practice, the process of synchronously updating the weights of the model replicas is133handled at the level of each individual weight variable. This is done through a using134a `jax.sharding.NamedSharding` that is configured to replicate the variables.135136**How to use it**137138To do single-host, multi-device synchronous training with a Keras model, you139would use the `jax.sharding` features. Here's how it works:140141- We first create a device mesh using `mesh_utils.create_device_mesh`.142- We use `jax.sharding.Mesh`, `jax.sharding.NamedSharding` and143`jax.sharding.PartitionSpec` to define how to partition JAX arrays.144- We specify that we want to replicate the model and optimizer variables145across all devices by using a spec with no axis.146- We specify that we want to shard the data across devices by using a spec147that splits along the batch dimension.148- We use `jax.device_put` to replicate the model and optimizer variables across149devices. 

Here's the flow, where each step is split into its own utility function:
"""

# Config
num_epochs = 2
batch_size = 64

train_data, eval_data = get_datasets()
train_data = train_data.batch(batch_size, drop_remainder=True)

model = get_model()
optimizer = keras.optimizers.Adam(1e-3)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Initialize all state with .build()
(one_batch, one_batch_labels) = next(iter(train_data))
model.build(one_batch)
optimizer.build(model.trainable_variables)


# This is the loss function that will be differentiated.
# Keras provides a pure functional forward pass: model.stateless_call
def compute_loss(trainable_variables, non_trainable_variables, x, y):
    y_pred, updated_non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss_value = loss(y, y_pred)
    return loss_value, updated_non_trainable_variables


# Function to compute gradients
compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)


# Training step: Keras provides a pure functional optimizer.stateless_apply
@jax.jit
def train_step(train_state, x, y):
    trainable_variables, non_trainable_variables, optimizer_variables = train_state
    (loss_value, non_trainable_variables), grads = compute_gradients(
        trainable_variables, non_trainable_variables, x, y
    )

    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )

    return loss_value, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )


# Replicate the model and optimizer variables on all devices
def get_replicated_train_state(devices):
    # All variables will be replicated on all devices
    var_mesh = Mesh(devices, axis_names=("_"))
    # In NamedSharding, axes not mentioned are replicated (all axes here)
    var_replication = NamedSharding(var_mesh, P())

    # Apply the distribution settings to the model variables
    trainable_variables = jax.device_put(model.trainable_variables, var_replication)
    non_trainable_variables = jax.device_put(
        model.non_trainable_variables, var_replication
    )
    optimizer_variables = jax.device_put(optimizer.variables, var_replication)

    # Combine all state in a tuple
    return (trainable_variables, non_trainable_variables, optimizer_variables)


num_devices = len(jax.local_devices())
print(f"Running on {num_devices} devices: {jax.local_devices()}")
devices = mesh_utils.create_device_mesh((num_devices,))

# Data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
data_sharding = NamedSharding(
    data_mesh,
    P(
        "batch",
    ),
)  # naming axes of the sharded partition

# Display data sharding
x, y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("Data sharding")
jax.debug.visualize_array_sharding(jax.numpy.reshape(sharded_x, [-1, 28 * 28]))

train_state = get_replicated_train_state(devices)
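
"""
As a quick sanity check, we can also look at how one of the replicated variables is laid
out. Since the variables use the `P()` spec, each device holds a full copy of the tensor,
so the visualization below should show a single tile spanning all devices (contrast this
with the per-device slices shown for the data batch above).
"""

# Display the sharding of one replicated variable (the first trainable weight),
# reshaped to 2D as was done for the data batch above.
print("Variable sharding (replicated)")
print(train_state[0][0].sharding)
jax.debug.visualize_array_sharding(jax.numpy.reshape(train_state[0][0], [1, -1]))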

# Custom training loop
for epoch in range(num_epochs):
    data_iter = iter(train_data)
    for data in data_iter:
        x, y = data
        sharded_x = jax.device_put(x.numpy(), data_sharding)
        loss_value, train_state = train_step(train_state, sharded_x, y.numpy())
    print("Epoch", epoch, "loss:", loss_value)

# Post-processing: write the updated state back into the model variables
trainable_variables, non_trainable_variables, optimizer_variables = train_state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)

"""
That's it!
"""
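
"""
As an optional final step, here is a sketch of how the held-out `eval_data` created
earlier could be evaluated with the same sharding setup, using the pure functional
`model.stateless_call` for inference. The `predict_step` helper below is illustrative
and not part of the training flow above.
"""


# Evaluation sketch: compute test accuracy from the final replicated state.
@jax.jit
def predict_step(trainable_variables, non_trainable_variables, x):
    y_pred, _ = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=False
    )
    return jax.numpy.argmax(y_pred, axis=-1)


num_correct = 0
num_total = 0
# drop_remainder keeps every batch evenly divisible across the devices.
for x, y in eval_data.batch(batch_size, drop_remainder=True):
    sharded_x = jax.device_put(x.numpy(), data_sharding)
    predictions = predict_step(trainable_variables, non_trainable_variables, sharded_x)
    num_correct += int(np.sum(np.asarray(predictions) == y.numpy()))
    num_total += int(y.shape[0])
print("Eval accuracy:", num_correct / num_total)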