"""1Title: Tune hyperparameters in your custom training loop2Authors: Tom O'Malley, Haifeng Jin3Date created: 2019/10/284Last modified: 2022/01/125Description: Use `HyperModel.fit()` to tune training hyperparameters (such as batch size).6Accelerator: GPU7"""89"""shell10pip install keras-tuner -q11"""1213"""14## Introduction1516The `HyperModel` class in KerasTuner provides a convenient way to define your17search space in a reusable object. You can override `HyperModel.build()` to18define and hypertune the model itself. To hypertune the training process (e.g.19by selecting the proper batch size, number of training epochs, or data20augmentation setup), you can override `HyperModel.fit()`, where you can access:2122- The `hp` object, which is an instance of `keras_tuner.HyperParameters`23- The model built by `HyperModel.build()`2425A basic example is shown in the "tune model training" section of26[Getting Started with KerasTuner](https://keras.io/guides/keras_tuner/getting_started/#tune-model-training).2728## Tuning the custom training loop2930In this guide, we will subclass the `HyperModel` class and write a custom31training loop by overriding `HyperModel.fit()`. For how to write a custom32training loop with Keras, you can refer to the guide33[Writing a training loop from scratch](https://keras.io/guides/writing_a_training_loop_from_scratch/).3435First, we import the libraries we need, and we create datasets for training and36validation. Here, we just use some random data for demonstration purposes.37"""3839import keras_tuner40import tensorflow as tf41import keras42import numpy as np434445x_train = np.random.rand(1000, 28, 28, 1)46y_train = np.random.randint(0, 10, (1000, 1))47x_val = np.random.rand(1000, 28, 28, 1)48y_val = np.random.randint(0, 10, (1000, 1))4950"""51Then, we subclass the `HyperModel` class as `MyHyperModel`. In52`MyHyperModel.build()`, we build a simple Keras model to do image53classification for 10 different classes. `MyHyperModel.fit()` accepts several54arguments. Its signature is shown below:5556```python57def fit(self, hp, model, x, y, validation_data, callbacks=None, **kwargs):58```5960* The `hp` argument is for defining the hyperparameters.61* The `model` argument is the model returned by `MyHyperModel.build()`.62* `x`, `y`, and `validation_data` are all custom-defined arguments. We will63pass our data to them by calling `tuner.search(x=x, y=y,64validation_data=(x_val, y_val))` later. You can define any number of them and65give custom names.66* The `callbacks` argument was intended to be used with `model.fit()`.67KerasTuner put some helpful Keras callbacks in it, for example, the callback68for checkpointing the model at its best epoch.6970We will manually call the callbacks in the custom training loop. Before we71can call them, we need to assign our model to them with the following code so72that they have access to the model for checkpointing.7374```py75for callback in callbacks:76callback.model = model77```7879In this example, we only called the `on_epoch_end()` method of the callbacks80to help us checkpoint the model. You may also call other callback methods81if needed. If you don't need to save the model, you don't need to use the82callbacks.8384In the custom training loop, we tune the batch size of the dataset as we wrap85the NumPy data into a `tf.data.Dataset`. Note that you can tune any86preprocessing steps here as well. We also tune the learning rate of the87optimizer.8889We will use the validation loss as the evaluation metric for the model. 

We will use the validation loss as the evaluation metric for the model. To
compute the mean validation loss, we will use `keras.metrics.Mean()`, which
averages the validation loss across the batches. We need to return the
validation loss for the tuner to make a record.
"""


class MyHyperModel(keras_tuner.HyperModel):
    def build(self, hp):
        """Builds a simple classification model."""
        inputs = keras.Input(shape=(28, 28, 1))
        x = keras.layers.Flatten()(inputs)
        x = keras.layers.Dense(
            units=hp.Choice("units", [32, 64, 128]), activation="relu"
        )(x)
        outputs = keras.layers.Dense(10)(x)
        return keras.Model(inputs=inputs, outputs=outputs)

    def fit(self, hp, model, x, y, validation_data, callbacks=None, **kwargs):
        # Convert the datasets to tf.data.Dataset.
        batch_size = hp.Int("batch_size", 32, 128, step=32, default=64)
        train_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)
        validation_data = tf.data.Dataset.from_tensor_slices(validation_data).batch(
            batch_size
        )

        # Define the optimizer.
        optimizer = keras.optimizers.Adam(
            hp.Float("learning_rate", 1e-4, 1e-2, sampling="log", default=1e-3)
        )
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        # The metric to track validation loss.
        epoch_loss_metric = keras.metrics.Mean()

        # Function to run the train step.
        @tf.function
        def run_train_step(images, labels):
            with tf.GradientTape() as tape:
                logits = model(images)
                loss = loss_fn(labels, logits)
                # Add any regularization losses.
                if model.losses:
                    loss += tf.math.add_n(model.losses)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Function to run the validation step.
        @tf.function
        def run_val_step(images, labels):
            logits = model(images)
            loss = loss_fn(labels, logits)
            # Update the metric.
            epoch_loss_metric.update_state(loss)

        # Assign the model to the callbacks.
        for callback in callbacks:
            callback.set_model(model)

        # Record the best validation loss value.
        best_epoch_loss = float("inf")

        # The custom training loop.
        for epoch in range(2):
            print(f"Epoch: {epoch}")

            # Iterate the training data to run the training step.
            for images, labels in train_ds:
                run_train_step(images, labels)

            # Iterate the validation data to run the validation step.
            for images, labels in validation_data:
                run_val_step(images, labels)

            # Calling the callbacks after epoch.
            epoch_loss = float(epoch_loss_metric.result().numpy())
            for callback in callbacks:
                # "my_metric" is the objective passed to the tuner.
                callback.on_epoch_end(epoch, logs={"my_metric": epoch_loss})
            epoch_loss_metric.reset_state()

            print(f"Epoch loss: {epoch_loss}")
            best_epoch_loss = min(best_epoch_loss, epoch_loss)

        # Return the evaluation metric value.
        return best_epoch_loss
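

"""
Before handing the hypermodel to a tuner, it can be useful to sanity-check it by
building a model with the default hyperparameter values and calling `fit()` once
directly. This is only a sketch (calling `fit()` by hand and passing an empty
`callbacks` list are assumptions for illustration, not something KerasTuner
requires):

```python
# Hypothetical sanity check: run the custom training loop once without a tuner.
hp = keras_tuner.HyperParameters()
hypermodel = MyHyperModel()
model = hypermodel.build(hp)
hypermodel.fit(
    hp, model, x=x_train, y=y_train, validation_data=(x_val, y_val), callbacks=[]
)
```
"""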

"""
Now, we can initialize the tuner. Here, we use `Objective("my_metric", "min")`
as our metric to be minimized. The objective name should be consistent with the
one you use as the key in the `logs` passed to the `on_epoch_end()` method of
the callbacks. The callbacks need to use this value in the `logs` to find the
best epoch to checkpoint the model.
"""

tuner = keras_tuner.RandomSearch(
    objective=keras_tuner.Objective("my_metric", "min"),
    max_trials=2,
    hypermodel=MyHyperModel(),
    directory="results",
    project_name="custom_training",
    overwrite=True,
)

"""
We start the search by passing the arguments we defined in the signature of
`MyHyperModel.fit()` to `tuner.search()`.
"""

tuner.search(x=x_train, y=y_train, validation_data=(x_val, y_val))

"""
Finally, we can retrieve the results.
"""

best_hps = tuner.get_best_hyperparameters()[0]
print(best_hps.values)

best_model = tuner.get_best_models()[0]
best_model.summary()

"""
In summary, to tune the hyperparameters in your custom training loop, you just
override `HyperModel.fit()` to train the model and return the evaluation
results. With the provided callbacks, you can easily save the trained models at
their best epochs and load the best models later.

To find out more about the basics of KerasTuner, please see
[Getting Started with KerasTuner](https://keras.io/guides/keras_tuner/getting_started/).
"""
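
"""
If you would rather retrain with the best hyperparameters than load the
checkpointed model, you can reuse them with the same `HyperModel`. The snippet
below is only a sketch (retraining on the concatenated training and validation
data and passing an empty `callbacks` list are illustrative assumptions):

```python
# Hypothetical: rebuild and retrain a fresh model with the best hyperparameters.
hypermodel = MyHyperModel()
model = hypermodel.build(best_hps)
hypermodel.fit(
    best_hps,
    model,
    x=np.concatenate((x_train, x_val)),
    y=np.concatenate((y_train, y_val)),
    validation_data=(x_val, y_val),
    callbacks=[],
)
```
"""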