Path: blob/master/guides/keras_tuner/custom_tuner.py
8301 views
"""1Title: Tune hyperparameters in your custom training loop2Authors: Tom O'Malley, Haifeng Jin3Date created: 2019/10/284Last modified: 2022/01/125Description: Use `HyperModel.fit()` to tune training hyperparameters (such as batch size).6Accelerator: GPU7"""89"""shell10pip install keras-tuner -q11"""1213"""14## Introduction1516The `HyperModel` class in KerasTuner provides a convenient way to define your17search space in a reusable object. You can override `HyperModel.build()` to18define and hypertune the model itself. To hypertune the training process (e.g.19by selecting the proper batch size, number of training epochs, or data20augmentation setup), you can override `HyperModel.fit()`, where you can access:2122- The `hp` object, which is an instance of `keras_tuner.HyperParameters`23- The model built by `HyperModel.build()`2425A basic example is shown in the "tune model training" section of26[Getting Started with KerasTuner](https://keras.io/guides/keras_tuner/getting_started/#tune-model-training).2728## Tuning the custom training loop2930In this guide, we will subclass the `HyperModel` class and write a custom31training loop by overriding `HyperModel.fit()`. For how to write a custom32training loop with Keras, you can refer to the guide33[Writing a training loop from scratch](https://keras.io/guides/writing_a_training_loop_from_scratch/).3435First, we import the libraries we need, and we create datasets for training and36validation. Here, we just use some random data for demonstration purposes.37"""3839import keras_tuner40import tensorflow as tf41import keras42import numpy as np4344x_train = np.random.rand(1000, 28, 28, 1)45y_train = np.random.randint(0, 10, (1000, 1))46x_val = np.random.rand(1000, 28, 28, 1)47y_val = np.random.randint(0, 10, (1000, 1))4849"""50Then, we subclass the `HyperModel` class as `MyHyperModel`. In51`MyHyperModel.build()`, we build a simple Keras model to do image52classification for 10 different classes. `MyHyperModel.fit()` accepts several53arguments. Its signature is shown below:5455```python56def fit(self, hp, model, x, y, validation_data, callbacks=None, **kwargs):57```5859* The `hp` argument is for defining the hyperparameters.60* The `model` argument is the model returned by `MyHyperModel.build()`.61* `x`, `y`, and `validation_data` are all custom-defined arguments. We will62pass our data to them by calling `tuner.search(x=x, y=y,63validation_data=(x_val, y_val))` later. You can define any number of them and64give custom names.65* The `callbacks` argument was intended to be used with `model.fit()`.66KerasTuner put some helpful Keras callbacks in it, for example, the callback67for checkpointing the model at its best epoch.6869We will manually call the callbacks in the custom training loop. Before we70can call them, we need to assign our model to them with the following code so71that they have access to the model for checkpointing.7273```py74for callback in callbacks:75callback.model = model76```7778In this example, we only called the `on_epoch_end()` method of the callbacks79to help us checkpoint the model. You may also call other callback methods80if needed. If you don't need to save the model, you don't need to use the81callbacks.8283In the custom training loop, we tune the batch size of the dataset as we wrap84the NumPy data into a `tf.data.Dataset`. Note that you can tune any85preprocessing steps here as well. We also tune the learning rate of the86optimizer.8788We will use the validation loss as the evaluation metric for the model. To89compute the mean validation loss, we will use `keras.metrics.Mean()`, which90averages the validation loss across the batches. We need to return the91validation loss for the tuner to make a record.92"""939495class MyHyperModel(keras_tuner.HyperModel):96def build(self, hp):97"""Builds a convolutional model."""98inputs = keras.Input(shape=(28, 28, 1))99x = keras.layers.Flatten()(inputs)100x = keras.layers.Dense(101units=hp.Choice("units", [32, 64, 128]), activation="relu"102)(x)103outputs = keras.layers.Dense(10)(x)104return keras.Model(inputs=inputs, outputs=outputs)105106def fit(self, hp, model, x, y, validation_data, callbacks=None, **kwargs):107# Convert the datasets to tf.data.Dataset.108batch_size = hp.Int("batch_size", 32, 128, step=32, default=64)109train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(110batch_size111)112validation_data = tf.data.Dataset.from_tensor_slices(validation_data).batch(113batch_size114)115116# Define the optimizer.117optimizer = keras.optimizers.Adam(118hp.Float("learning_rate", 1e-4, 1e-2, sampling="log", default=1e-3)119)120loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)121122# The metric to track validation loss.123epoch_loss_metric = keras.metrics.Mean()124125# Function to run the train step.126@tf.function127def run_train_step(images, labels):128with tf.GradientTape() as tape:129logits = model(images)130loss = loss_fn(labels, logits)131# Add any regularization losses.132if model.losses:133loss += tf.math.add_n(model.losses)134gradients = tape.gradient(loss, model.trainable_variables)135optimizer.apply_gradients(zip(gradients, model.trainable_variables))136137# Function to run the validation step.138@tf.function139def run_val_step(images, labels):140logits = model(images)141loss = loss_fn(labels, logits)142# Update the metric.143epoch_loss_metric.update_state(loss)144145# Assign the model to the callbacks.146for callback in callbacks:147callback.set_model(model)148149# Record the best validation loss value150best_epoch_loss = float("inf")151152# The custom training loop.153for epoch in range(2):154print(f"Epoch: {epoch}")155156# Iterate the training data to run the training step.157for images, labels in train_ds:158run_train_step(images, labels)159160# Iterate the validation data to run the validation step.161for images, labels in validation_data:162run_val_step(images, labels)163164# Calling the callbacks after epoch.165epoch_loss = float(epoch_loss_metric.result().numpy())166for callback in callbacks:167# The "my_metric" is the objective passed to the tuner.168callback.on_epoch_end(epoch, logs={"my_metric": epoch_loss})169epoch_loss_metric.reset_state()170171print(f"Epoch loss: {epoch_loss}")172best_epoch_loss = min(best_epoch_loss, epoch_loss)173174# Return the evaluation metric value.175return best_epoch_loss176177178"""179Now, we can initialize the tuner. Here, we use `Objective("my_metric", "min")`180as our metric to be minimized. The objective name should be consistent with the181one you use as the key in the `logs` passed to the 'on_epoch_end()' method of182the callbacks. The callbacks need to use this value in the `logs` to find the183best epoch to checkpoint the model.184185"""186tuner = keras_tuner.RandomSearch(187objective=keras_tuner.Objective("my_metric", "min"),188max_trials=2,189hypermodel=MyHyperModel(),190directory="results",191project_name="custom_training",192overwrite=True,193)194195196"""197We start the search by passing the arguments we defined in the signature of198`MyHyperModel.fit()` to `tuner.search()`.199"""200201tuner.search(x=x_train, y=y_train, validation_data=(x_val, y_val))202203"""204Finally, we can retrieve the results.205"""206207best_hps = tuner.get_best_hyperparameters()[0]208print(best_hps.values)209210best_model = tuner.get_best_models()[0]211best_model.summary()212213"""214In summary, to tune the hyperparameters in your custom training loop, you just215override `HyperModel.fit()` to train the model and return the evaluation216results. With the provided callbacks, you can easily save the trained models at217their best epochs and load the best models later.218219To find out more about the basics of KerasTuner, please see220[Getting Started with KerasTuner](https://keras.io/guides/keras_tuner/getting_started/).221"""222223224