"""
Title: Tune hyperparameters in your custom training loop
Authors: Tom O'Malley, Haifeng Jin
Date created: 2019/10/28
Last modified: 2022/01/12
Description: Use `HyperModel.fit()` to tune training hyperparameters (such as batch size).
Accelerator: GPU
"""

"""shell
pip install keras-tuner -q
"""

"""
## Introduction

The `HyperModel` class in KerasTuner provides a convenient way to define your
search space in a reusable object. You can override `HyperModel.build()` to
define and hypertune the model itself. To hypertune the training process (e.g.
by selecting the proper batch size, number of training epochs, or data
augmentation setup), you can override `HyperModel.fit()`, where you can access:

- The `hp` object, which is an instance of `keras_tuner.HyperParameters`
- The model built by `HyperModel.build()`

A basic example is shown in the "tune model training" section of
[Getting Started with KerasTuner](https://keras.io/guides/keras_tuner/getting_started/#tune-model-training).
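
For a rough sense of the pattern before we write a custom loop, the sketch
below mirrors that basic example (the class name `SimpleHyperModel` is only
illustrative): `fit()` simply forwards to `model.fit()` with a tuned batch
size.

```python
class SimpleHyperModel(keras_tuner.HyperModel):
    def build(self, hp):
        # Define and hypertune the model itself.
        model = keras.Sequential(
            [
                keras.layers.Flatten(),
                keras.layers.Dense(hp.Choice("units", [32, 64]), activation="relu"),
                keras.layers.Dense(10),
            ]
        )
        model.compile(loss="sparse_categorical_crossentropy")
        return model

    def fit(self, hp, model, *args, **kwargs):
        # Hypertune the training process: here, only the batch size.
        return model.fit(*args, batch_size=hp.Choice("batch_size", [16, 32]), **kwargs)
```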

## Tuning the custom training loop

In this guide, we will subclass the `HyperModel` class and write a custom
training loop by overriding `HyperModel.fit()`. For how to write a custom
training loop with Keras, you can refer to the guide
[Writing a training loop from scratch](https://keras.io/guides/writing_a_training_loop_from_scratch/).

First, we import the libraries we need, and we create datasets for training and
validation. Here, we just use some random data for demonstration purposes.
"""

import keras_tuner
import tensorflow as tf
import keras
import numpy as np

x_train = np.random.rand(1000, 28, 28, 1)
y_train = np.random.randint(0, 10, (1000, 1))
x_val = np.random.rand(1000, 28, 28, 1)
y_val = np.random.randint(0, 10, (1000, 1))
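
"""
If you prefer to run the guide on a real dataset, you could load MNIST instead
and reshape it to the same format (a sketch, not used in the code below):

```python
(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0
x_val = x_val.reshape(-1, 28, 28, 1).astype("float32") / 255.0
```
"""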

"""
Then, we subclass the `HyperModel` class as `MyHyperModel`. In
`MyHyperModel.build()`, we build a simple Keras model to do image
classification for 10 different classes. `MyHyperModel.fit()` accepts several
arguments. Its signature is shown below:

```python
def fit(self, hp, model, x, y, validation_data, callbacks=None, **kwargs):
```

* The `hp` argument is for defining the hyperparameters.
* The `model` argument is the model returned by `MyHyperModel.build()`.
* `x`, `y`, and `validation_data` are all custom-defined arguments. We will
pass our data to them by calling `tuner.search(x=x, y=y,
validation_data=(x_val, y_val))` later. You can define any number of them and
give them custom names (see the sketch after this list).
* The `callbacks` argument was intended to be used with `model.fit()`.
KerasTuner puts some helpful Keras callbacks in it, for example, the callback
for checkpointing the model at its best epoch.
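
For example, if you add a custom argument to the signature, you pass it to the
search call with the same keyword (a sketch with a hypothetical `augment`
argument that is not used in this guide):

```python
def fit(self, hp, model, x, y, validation_data, augment=False, callbacks=None, **kwargs):
    ...

tuner.search(
    x=x_train, y=y_train, validation_data=(x_val, y_val), augment=True
)
```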

We will manually call the callbacks in the custom training loop. Before we
can call them, we need to assign our model to them with the following code so
that they have access to the model for checkpointing.

```py
for callback in callbacks:
    callback.set_model(model)
```

In this example, we only call the `on_epoch_end()` method of the callbacks
to help us checkpoint the model. You may also call other callback methods
if needed. If you don't need to save the model, you don't need to use the
callbacks.

In the custom training loop, we tune the batch size of the dataset as we wrap
the NumPy data into a `tf.data.Dataset`. Note that you can tune any
preprocessing steps here as well. We also tune the learning rate of the
optimizer.
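
For example, a tunable preprocessing step inside `fit()` could look like the
following sketch (a hypothetical `shuffle` hyperparameter that is not used in
the code below):

```python
if hp.Boolean("shuffle", default=True):
    train_ds = train_ds.shuffle(buffer_size=1000)
```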

We will use the validation loss as the evaluation metric for the model. To
compute the mean validation loss, we will use `keras.metrics.Mean()`, which
averages the validation loss across the batches. We need to return the
validation loss for the tuner to make a record.
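
For reference, `keras.metrics.Mean()` simply averages all values passed to
`update_state()` (a small illustration, not part of the guide's code):

```python
mean_metric = keras.metrics.Mean()
mean_metric.update_state(2.0)
mean_metric.update_state(4.0)
mean_metric.result()  # 3.0
```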
"""


class MyHyperModel(keras_tuner.HyperModel):
    def build(self, hp):
        """Builds a simple image classification model."""
        inputs = keras.Input(shape=(28, 28, 1))
        x = keras.layers.Flatten()(inputs)
        x = keras.layers.Dense(
            units=hp.Choice("units", [32, 64, 128]), activation="relu"
        )(x)
        outputs = keras.layers.Dense(10)(x)
        return keras.Model(inputs=inputs, outputs=outputs)

    def fit(self, hp, model, x, y, validation_data, callbacks=None, **kwargs):
        # Convert the datasets to tf.data.Dataset.
        batch_size = hp.Int("batch_size", 32, 128, step=32, default=64)
        train_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)
        validation_data = tf.data.Dataset.from_tensor_slices(validation_data).batch(
            batch_size
        )

        # Define the optimizer.
        optimizer = keras.optimizers.Adam(
            hp.Float("learning_rate", 1e-4, 1e-2, sampling="log", default=1e-3)
        )
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        # The metric to track validation loss.
        epoch_loss_metric = keras.metrics.Mean()

        # Function to run the train step.
        @tf.function
        def run_train_step(images, labels):
            with tf.GradientTape() as tape:
                logits = model(images)
                loss = loss_fn(labels, logits)
                # Add any regularization losses.
                if model.losses:
                    loss += tf.math.add_n(model.losses)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Function to run the validation step.
        @tf.function
        def run_val_step(images, labels):
            logits = model(images)
            loss = loss_fn(labels, logits)
            # Update the metric.
            epoch_loss_metric.update_state(loss)

        # Assign the model to the callbacks.
        for callback in callbacks:
            callback.set_model(model)

        # Record the best validation loss value.
        best_epoch_loss = float("inf")

        # The custom training loop.
        for epoch in range(2):
            print(f"Epoch: {epoch}")

            # Iterate the training data to run the training step.
            for images, labels in train_ds:
                run_train_step(images, labels)

            # Iterate the validation data to run the validation step.
            for images, labels in validation_data:
                run_val_step(images, labels)

            # Call the callbacks at the end of the epoch.
            epoch_loss = float(epoch_loss_metric.result().numpy())
            for callback in callbacks:
                # "my_metric" is the objective passed to the tuner.
                callback.on_epoch_end(epoch, logs={"my_metric": epoch_loss})
            epoch_loss_metric.reset_state()

            print(f"Epoch loss: {epoch_loss}")
            best_epoch_loss = min(best_epoch_loss, epoch_loss)

        # Return the evaluation metric value.
        return best_epoch_loss


"""
Now, we can initialize the tuner. Here, we use `Objective("my_metric", "min")`
as our metric to be minimized. The objective name should be consistent with the
one you use as the key in the `logs` passed to the `on_epoch_end()` method of
the callbacks. The callbacks need to use this value in the `logs` to find the
best epoch to checkpoint the model.
"""

tuner = keras_tuner.RandomSearch(
    objective=keras_tuner.Objective("my_metric", "min"),
    max_trials=2,
    hypermodel=MyHyperModel(),
    directory="results",
    project_name="custom_training",
    overwrite=True,
)


"""
We start the search by passing the arguments we defined in the signature of
`MyHyperModel.fit()` to `tuner.search()`.
"""

tuner.search(x=x_train, y=y_train, validation_data=(x_val, y_val))

"""
Finally, we can retrieve the results.
"""

best_hps = tuner.get_best_hyperparameters()[0]
print(best_hps.values)

best_model = tuner.get_best_models()[0]
best_model.summary()
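
"""
If you want to retrain a model from scratch with the best hyperparameters, you
can rebuild it and call `MyHyperModel.fit()` yourself. The sketch below assumes
the objects defined above and passes an empty callback list, since we are not
checkpointing here:

```python
hypermodel = MyHyperModel()
model = hypermodel.build(best_hps)
hypermodel.fit(
    best_hps,
    model,
    x=x_train,
    y=y_train,
    validation_data=(x_val, y_val),
    callbacks=[],
)
```
"""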

"""
In summary, to tune the hyperparameters in your custom training loop, you just
override `HyperModel.fit()` to train the model and return the evaluation
results. With the provided callbacks, you can easily save the trained models at
their best epochs and load the best models later.

To find out more about the basics of KerasTuner, please see
[Getting Started with KerasTuner](https://keras.io/guides/keras_tuner/getting_started/).
"""