"""
Title: Tune hyperparameters in your custom training loop
Authors: Tom O'Malley, Haifeng Jin
Date created: 2019/10/28
Last modified: 2022/01/12
Description: Use `HyperModel.fit()` to tune training hyperparameters (such as batch size).
Accelerator: GPU
"""

"""shell
pip install keras-tuner -q
"""

"""
15
## Introduction
16
17
The `HyperModel` class in KerasTuner provides a convenient way to define your
18
search space in a reusable object. You can override `HyperModel.build()` to
19
define and hypertune the model itself. To hypertune the training process (e.g.
20
by selecting the proper batch size, number of training epochs, or data
21
augmentation setup), you can override `HyperModel.fit()`, where you can access:
22
23
- The `hp` object, which is an instance of `keras_tuner.HyperParameters`
24
- The model built by `HyperModel.build()`
25
26
A basic example is shown in the "tune model training" section of
27
[Getting Started with KerasTuner](https://keras.io/guides/keras_tuner/getting_started/#tune-model-training).
28
29
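
To make the division of labor concrete, here is a minimal sketch of a
`HyperModel` that overrides both methods. It is illustrative only; the layer
sizes and hyperparameter names are assumptions, not part of the guide below:

```python
class SketchHyperModel(keras_tuner.HyperModel):
    def build(self, hp):
        # `build()` defines and hypertunes the model itself.
        model = keras.Sequential(
            [
                keras.layers.Flatten(),
                keras.layers.Dense(hp.Int("units", 32, 128, step=32), activation="relu"),
                keras.layers.Dense(10),
            ]
        )
        model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
        return model

    def fit(self, hp, model, *args, **kwargs):
        # `fit()` hypertunes the training process, e.g. the batch size.
        return model.fit(*args, batch_size=hp.Choice("batch_size", [32, 64]), **kwargs)
```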

## Tuning the custom training loop

In this guide, we will subclass the `HyperModel` class and write a custom
training loop by overriding `HyperModel.fit()`. For how to write a custom
training loop with Keras, you can refer to the guide
[Writing a training loop from scratch](https://keras.io/guides/writing_a_training_loop_from_scratch/).

First, we import the libraries we need, and we create datasets for training and
validation. Here, we just use some random data for demonstration purposes.
"""

import keras_tuner
import tensorflow as tf
import keras
import numpy as np


x_train = np.random.rand(1000, 28, 28, 1)
y_train = np.random.randint(0, 10, (1000, 1))
x_val = np.random.rand(1000, 28, 28, 1)
y_val = np.random.randint(0, 10, (1000, 1))

"""
52
Then, we subclass the `HyperModel` class as `MyHyperModel`. In
53
`MyHyperModel.build()`, we build a simple Keras model to do image
54
classification for 10 different classes. `MyHyperModel.fit()` accepts several
55
arguments. Its signature is shown below:
56
57
```python
58
def fit(self, hp, model, x, y, validation_data, callbacks=None, **kwargs):
59
```
60
61
* The `hp` argument is for defining the hyperparameters.
62
* The `model` argument is the model returned by `MyHyperModel.build()`.
63
* `x`, `y`, and `validation_data` are all custom-defined arguments. We will
64
pass our data to them by calling `tuner.search(x=x, y=y,
65
validation_data=(x_val, y_val))` later. You can define any number of them and
66
give custom names.
67
* The `callbacks` argument was intended to be used with `model.fit()`.
68
KerasTuner put some helpful Keras callbacks in it, for example, the callback
69
for checkpointing the model at its best epoch.
70
71
We will manually call the callbacks in the custom training loop. Before we
72
can call them, we need to assign our model to them with the following code so
73
that they have access to the model for checkpointing.
74
75
```py
76
for callback in callbacks:
77
callback.model = model
78
```
79
80
In this example, we only called the `on_epoch_end()` method of the callbacks
81
to help us checkpoint the model. You may also call other callback methods
82
if needed. If you don't need to save the model, you don't need to use the
83
callbacks.
84
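
For instance, a training loop that also notifies the callbacks at the start and
end of training and at the start of each epoch could look like the following
sketch (an illustration only; the loop in this guide calls `on_epoch_end()`
alone):

```python
for callback in callbacks:
    callback.set_model(model)
    callback.on_train_begin()

for epoch in range(num_epochs):
    for callback in callbacks:
        callback.on_epoch_begin(epoch)
    # ... run the training and validation steps ...
    for callback in callbacks:
        callback.on_epoch_end(epoch, logs={"my_metric": epoch_loss})

for callback in callbacks:
    callback.on_train_end()
```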

In the custom training loop, we tune the batch size of the dataset as we wrap
the NumPy data into a `tf.data.Dataset`. Note that you can tune any
preprocessing steps here as well. We also tune the learning rate of the
optimizer.
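
For example, a preprocessing choice such as whether to shuffle the training
data could be exposed as a boolean hyperparameter. This is a sketch under that
assumption; the `shuffle` hyperparameter is not part of the code below:

```python
if hp.Boolean("shuffle", default=True):
    train_ds = train_ds.shuffle(buffer_size=1000)
```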

We will use the validation loss as the evaluation metric for the model. To
compute the mean validation loss, we will use `keras.metrics.Mean()`, which
averages the validation loss across the batches. We need to return the
validation loss for the tuner to make a record.
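
In isolation, the `keras.metrics.Mean()` accumulate/read/reset cycle looks like
the following sketch (the values are made up for illustration):

```python
mean_loss = keras.metrics.Mean()
mean_loss.update_state(1.0)  # Accumulate a per-batch value.
mean_loss.update_state(3.0)
print(float(mean_loss.result()))  # Running average: 2.0
mean_loss.reset_state()  # Clear the state for the next epoch.
```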
"""
95
96
97
class MyHyperModel(keras_tuner.HyperModel):
    def build(self, hp):
        """Builds a convolutional model."""
        inputs = keras.Input(shape=(28, 28, 1))
        x = keras.layers.Flatten()(inputs)
        x = keras.layers.Dense(
            units=hp.Choice("units", [32, 64, 128]), activation="relu"
        )(x)
        outputs = keras.layers.Dense(10)(x)
        return keras.Model(inputs=inputs, outputs=outputs)

    def fit(self, hp, model, x, y, validation_data, callbacks=None, **kwargs):
        # Convert the passed-in data to tf.data.Dataset.
        batch_size = hp.Int("batch_size", 32, 128, step=32, default=64)
        train_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)
        validation_data = tf.data.Dataset.from_tensor_slices(validation_data).batch(
            batch_size
        )

        # Define the optimizer.
        optimizer = keras.optimizers.Adam(
            hp.Float("learning_rate", 1e-4, 1e-2, sampling="log", default=1e-3)
        )
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        # The metric to track validation loss.
        epoch_loss_metric = keras.metrics.Mean()

        # Function to run the train step.
        @tf.function
        def run_train_step(images, labels):
            with tf.GradientTape() as tape:
                logits = model(images)
                loss = loss_fn(labels, logits)
                # Add any regularization losses.
                if model.losses:
                    loss += tf.math.add_n(model.losses)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Function to run the validation step.
        @tf.function
        def run_val_step(images, labels):
            logits = model(images)
            loss = loss_fn(labels, logits)
            # Update the metric.
            epoch_loss_metric.update_state(loss)

        # Assign the model to the callbacks.
        for callback in callbacks:
            callback.set_model(model)

        # Record the best validation loss value.
        best_epoch_loss = float("inf")

        # The custom training loop.
        for epoch in range(2):
            print(f"Epoch: {epoch}")

            # Iterate the training data to run the training step.
            for images, labels in train_ds:
                run_train_step(images, labels)

            # Iterate the validation data to run the validation step.
            for images, labels in validation_data:
                run_val_step(images, labels)

            # Call the callbacks at the end of the epoch.
            epoch_loss = float(epoch_loss_metric.result().numpy())
            for callback in callbacks:
                # "my_metric" is the objective passed to the tuner.
                callback.on_epoch_end(epoch, logs={"my_metric": epoch_loss})
            epoch_loss_metric.reset_state()

            print(f"Epoch loss: {epoch_loss}")
            best_epoch_loss = min(best_epoch_loss, epoch_loss)

        # Return the evaluation metric value.
        return best_epoch_loss


"""
181
Now, we can initialize the tuner. Here, we use `Objective("my_metric", "min")`
182
as our metric to be minimized. The objective name should be consistent with the
183
one you use as the key in the `logs` passed to the 'on_epoch_end()' method of
184
the callbacks. The callbacks need to use this value in the `logs` to find the
185
best epoch to checkpoint the model.
186
187
"""
188
tuner = keras_tuner.RandomSearch(
    objective=keras_tuner.Objective("my_metric", "min"),
    max_trials=2,
    hypermodel=MyHyperModel(),
    directory="results",
    project_name="custom_training",
    overwrite=True,
)


"""
199
We start the search by passing the arguments we defined in the signature of
200
`MyHyperModel.fit()` to `tuner.search()`.
201
"""
202
203
tuner.search(x=x_train, y=y_train, validation_data=(x_val, y_val))
204
205
"""
206
Finally, we can retrieve the results.
207
"""
208
209
best_hps = tuner.get_best_hyperparameters()[0]
210
print(best_hps.values)
211
212
best_model = tuner.get_best_models()[0]
213
best_model.summary()
214
215
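"""
If you would rather retrain from scratch with the best hyperparameters instead
of reloading the checkpointed model, a minimal sketch (an illustration, not part
of the workflow above) is to rebuild and refit the hypermodel directly:

```python
hypermodel = MyHyperModel()
model = hypermodel.build(best_hps)
# Pass an empty callbacks list, since our custom `fit()` iterates over it.
hypermodel.fit(
    best_hps, model, x=x_train, y=y_train, validation_data=(x_val, y_val), callbacks=[]
)
```
"""
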
"""
216
In summary, to tune the hyperparameters in your custom training loop, you just
217
override `HyperModel.fit()` to train the model and return the evaluation
218
results. With the provided callbacks, you can easily save the trained models at
219
their best epochs and load the best models later.
220
221
To find out more about the basics of KerasTuner, please see
222
[Getting Started with KerasTuner](https://keras.io/guides/keras_tuner/getting_started/).
223
"""
224
225