"""
Title: Getting started with KerasTuner
Authors: Luca Invernizzi, James Long, Francois Chollet, Tom O'Malley, Haifeng Jin
Date created: 2019/05/31
Last modified: 2021/10/27
Description: The basics of using KerasTuner to tune model hyperparameters.
Accelerator: GPU
"""

"""shell
pip install keras-tuner -q
"""

"""
16
## Introduction
17
18
KerasTuner is a general-purpose hyperparameter tuning library. It has strong
19
integration with Keras workflows, but it isn't limited to them: you could use
20
it to tune scikit-learn models, or anything else. In this tutorial, you will
21
see how to tune model architecture, training process, and data preprocessing
22
steps with KerasTuner. Let's start from a simple example.
23
24
## Tune the model architecture
25
26
The first thing we need to do is writing a function, which returns a compiled
27
Keras model. It takes an argument `hp` for defining the hyperparameters while
28
building the model.
29
30
### Define the search space
31
32
In the following code example, we define a Keras model with two `Dense` layers.
33
We want to tune the number of units in the first `Dense` layer. We just define
34
an integer hyperparameter with `hp.Int('units', min_value=32, max_value=512, step=32)`,
35
whose range is from 32 to 512 inclusive. When sampling from it, the minimum
36
step for walking through the interval is 32.
37
"""

import keras
from keras import layers


def build_model(hp):
    model = keras.Sequential()
    model.add(layers.Flatten())
    model.add(
        layers.Dense(
            # Define the hyperparameter.
            units=hp.Int("units", min_value=32, max_value=512, step=32),
            activation="relu",
        )
    )
    model.add(layers.Dense(10, activation="softmax"))
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model


"""
You can quickly test if the model builds successfully.
"""

import keras_tuner

build_model(keras_tuner.HyperParameters())

"""
71
There are many other types of hyperparameters as well. We can define multiple
72
hyperparameters in the function. In the following code, we tune whether to
73
use a `Dropout` layer with `hp.Boolean()`, tune which activation function to
74
use with `hp.Choice()`, tune the learning rate of the optimizer with
75
`hp.Float()`.
76
"""
77

def build_model(hp):
    model = keras.Sequential()
    model.add(layers.Flatten())
    model.add(
        layers.Dense(
            # Tune number of units.
            units=hp.Int("units", min_value=32, max_value=512, step=32),
            # Tune the activation function to use.
            activation=hp.Choice("activation", ["relu", "tanh"]),
        )
    )
    # Tune whether to use dropout.
    if hp.Boolean("dropout"):
        model.add(layers.Dropout(rate=0.25))
    model.add(layers.Dense(10, activation="softmax"))
    # Define the optimizer learning rate as a hyperparameter.
    learning_rate = hp.Float("lr", min_value=1e-4, max_value=1e-2, sampling="log")
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model


build_model(keras_tuner.HyperParameters())

"""
107
As shown below, the hyperparameters are actual values. In fact, they are just
108
functions returning actual values. For example, `hp.Int()` returns an `int`
109
value. Therefore, you can put them into variables, for loops, or if
110
conditions.
111
"""
112
113
hp = keras_tuner.HyperParameters()
114
print(hp.Int("units", min_value=32, max_value=512, step=32))
115
116
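"""
Since they are plain Python values, you can use them directly in Python
control flow. A small illustration (not part of the original guide; the name
`"num_layers_demo"` is hypothetical):
"""

for i in range(hp.Int("num_layers_demo", 1, 3)):
    print(f"Would add Dense layer {i}")
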
"""
117
You can also define the hyperparameters in advance and keep your Keras code in
118
a separate function.
119
"""
120
121
122

def call_existing_code(units, activation, dropout, lr):
    model = keras.Sequential()
    model.add(layers.Flatten())
    model.add(layers.Dense(units=units, activation=activation))
    if dropout:
        model.add(layers.Dropout(rate=0.25))
    model.add(layers.Dense(10, activation="softmax"))
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model


def build_model(hp):
    units = hp.Int("units", min_value=32, max_value=512, step=32)
    activation = hp.Choice("activation", ["relu", "tanh"])
    dropout = hp.Boolean("dropout")
    lr = hp.Float("lr", min_value=1e-4, max_value=1e-2, sampling="log")
    # Call existing model-building code with the hyperparameter values.
    model = call_existing_code(
        units=units, activation=activation, dropout=dropout, lr=lr
    )
    return model


build_model(keras_tuner.HyperParameters())

"""
152
Each of the hyperparameters is uniquely identified by its name (the first
153
argument). To tune the number of units in different `Dense` layers separately
154
as different hyperparameters, we give them different names as `f"units_{i}"`.
155
156
Notably, this is also an example of creating conditional hyperparameters.
157
There are many hyperparameters specifying the number of units in the `Dense`
158
layers. The number of such hyperparameters is decided by the number of layers,
159
which is also a hyperparameter. Therefore, the total number of hyperparameters
160
used may be different from trial to trial. Some hyperparameter is only used
161
when a certain condition is satisfied. For example, `units_3` is only used
162
when `num_layers` is larger than 3. With KerasTuner, you can easily define
163
such hyperparameters dynamically while creating the model.
164
165
"""
166
167
168

def build_model(hp):
    model = keras.Sequential()
    model.add(layers.Flatten())
    # Tune the number of layers.
    for i in range(hp.Int("num_layers", 1, 3)):
        model.add(
            layers.Dense(
                # Tune number of units separately.
                units=hp.Int(f"units_{i}", min_value=32, max_value=512, step=32),
                activation=hp.Choice("activation", ["relu", "tanh"]),
            )
        )
    if hp.Boolean("dropout"):
        model.add(layers.Dropout(rate=0.25))
    model.add(layers.Dense(10, activation="softmax"))
    learning_rate = hp.Float("lr", min_value=1e-4, max_value=1e-2, sampling="log")
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model


build_model(keras_tuner.HyperParameters())

"""
195
### Start the search
196
197
After defining the search space, we need to select a tuner class to run the
198
search. You may choose from `RandomSearch`, `BayesianOptimization` and
199
`Hyperband`, which correspond to different tuning algorithms. Here we use
200
`RandomSearch` as an example.
201
202
To initialize the tuner, we need to specify several arguments in the initializer.
203
204
* `hypermodel`. The model-building function, which is `build_model` in our case.
205
* `objective`. The name of the objective to optimize (whether to minimize or
206
maximize is automatically inferred for built-in metrics). We will introduce how
207
to use custom metrics later in this tutorial.
208
* `max_trials`. The total number of trials to run during the search.
209
* `executions_per_trial`. The number of models that should be built and fit for
210
each trial. Different trials have different hyperparameter values. The
211
executions within the same trial have the same hyperparameter values. The
212
purpose of having multiple executions per trial is to reduce results variance
213
and therefore be able to more accurately assess the performance of a model. If
214
you want to get results faster, you could set `executions_per_trial=1` (single
215
round of training for each model configuration).
216
* `overwrite`. Control whether to overwrite the previous results in the same
217
directory or resume the previous search instead. Here we set `overwrite=True`
218
to start a new search and ignore any previous results.
219
* `directory`. A path to a directory for storing the search results.
220
* `project_name`. The name of the sub-directory in the `directory`.
221
222
"""
223
224
tuner = keras_tuner.RandomSearch(
225
hypermodel=build_model,
226
objective="val_accuracy",
227
max_trials=3,
228
executions_per_trial=2,
229
overwrite=True,
230
directory="my_dir",
231
project_name="helloworld",
232
)
233
234
"""
235
You can print a summary of the search space:
236
"""
237
238
tuner.search_space_summary()
239
240
"""
241
Before starting the search, let's prepare the MNIST dataset.
242
"""
243
244
import keras
245
import numpy as np
246
247
(x, y), (x_test, y_test) = keras.datasets.mnist.load_data()
248
249
x_train = x[:-10000]
250
x_val = x[-10000:]
251
y_train = y[:-10000]
252
y_val = y[-10000:]
253
254
x_train = np.expand_dims(x_train, -1).astype("float32") / 255.0
255
x_val = np.expand_dims(x_val, -1).astype("float32") / 255.0
256
x_test = np.expand_dims(x_test, -1).astype("float32") / 255.0
257
258
num_classes = 10
259
y_train = keras.utils.to_categorical(y_train, num_classes)
260
y_val = keras.utils.to_categorical(y_val, num_classes)
261
y_test = keras.utils.to_categorical(y_test, num_classes)
262
263
"""
264
Then, start the search for the best hyperparameter configuration.
265
All the arguments passed to `search` is passed to `model.fit()` in each
266
execution. Remember to pass `validation_data` to evaluate the model.
267
"""
268
269
tuner.search(x_train, y_train, epochs=2, validation_data=(x_val, y_val))
270
271
"""
272
During the `search`, the model-building function is called with different
273
hyperparameter values in different trial. In each trial, the tuner would
274
generate a new set of hyperparameter values to build the model. The model is
275
then fit and evaluated. The metrics are recorded. The tuner progressively
276
explores the space and finally finds a good set of hyperparameter values.
277
278
### Query the results
279
280
When search is over, you can retrieve the best model(s). The model is saved at
281
its best performing epoch evaluated on the `validation_data`.
282
"""
283
284
# Get the top 2 models.
models = tuner.get_best_models(num_models=2)
best_model = models[0]
best_model.summary()

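"""
Since we set aside a test set earlier, we can also evaluate the best model on
it. This is an optional sketch, not part of the original guide:
"""

best_model.evaluate(x_test, y_test)
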
"""
290
You can also print a summary of the search results.
291
"""
292
293
tuner.results_summary()
294
295
"""
296
You will find detailed logs, checkpoints, etc, in the folder
297
`my_dir/helloworld`, i.e. `directory/project_name`.
298
299
You can also visualize the tuning results using TensorBoard and HParams plugin.
300
For more information, please following
301
[this link](https://keras.io/guides/keras_tuner/visualize_tuning/).
302
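
For example, you can export TensorBoard logs by passing a
`keras.callbacks.TensorBoard` callback to `search()`. A sketch (not part of
the original guide; assumes a writable `/tmp/tb_logs` directory):

```python
tuner.search(
    x_train,
    y_train,
    epochs=2,
    validation_data=(x_val, y_val),
    callbacks=[keras.callbacks.TensorBoard("/tmp/tb_logs")],
)
```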

### Retrain the model

If you want to train the model with the entire dataset, you may retrieve the
best hyperparameters and retrain the model by yourself.
"""

# Get the top 5 hyperparameters.
best_hps = tuner.get_best_hyperparameters(5)
# Build the model with the best hp.
model = build_model(best_hps[0])
# Fit with the entire dataset.
x_all = np.concatenate((x_train, x_val))
y_all = np.concatenate((y_train, y_val))
model.fit(x=x_all, y=y_all, epochs=1)

"""
319
## Tune model training
320
321
To tune the model building process, we need to subclass the `HyperModel` class,
322
which also makes it easy to share and reuse hypermodels.
323
324
We need to override `HyperModel.build()` and `HyperModel.fit()` to tune the
325
model building and training process respectively. A `HyperModel.build()`
326
method is the same as the model-building function, which creates a Keras model
327
using the hyperparameters and returns it.
328
329
In `HyperModel.fit()`, you can access the model returned by
330
`HyperModel.build()`,`hp` and all the arguments passed to `search()`. You need
331
to train the model and return the training history.
332
333
In the following code, we will tune the `shuffle` argument in `model.fit()`.
334
335
It is generally not needed to tune the number of epochs because a built-in
336
callback is passed to `model.fit()` to save the model at its best epoch
337
evaluated by the `validation_data`.
338
339
> **Note**: The `**kwargs` should always be passed to `model.fit()` because it
340
contains the callbacks for model saving and tensorboard plugins.
341
"""
342
343
344

class MyHyperModel(keras_tuner.HyperModel):
    def build(self, hp):
        model = keras.Sequential()
        model.add(layers.Flatten())
        model.add(
            layers.Dense(
                units=hp.Int("units", min_value=32, max_value=512, step=32),
                activation="relu",
            )
        )
        model.add(layers.Dense(10, activation="softmax"))
        model.compile(
            optimizer="adam",
            loss="categorical_crossentropy",
            metrics=["accuracy"],
        )
        return model

    def fit(self, hp, model, *args, **kwargs):
        return model.fit(
            *args,
            # Tune whether to shuffle the data in each epoch.
            shuffle=hp.Boolean("shuffle"),
            **kwargs,
        )

"""
372
Again, we can do a quick check to see if the code works correctly.
373
"""
374
375
hp = keras_tuner.HyperParameters()
376
hypermodel = MyHyperModel()
377
model = hypermodel.build(hp)
378
hypermodel.fit(hp, model, np.random.rand(100, 28, 28), np.random.rand(100, 10))
379
380
"""
381
## Tune data preprocessing
382
383
To tune data preprocessing, we just add an additional step in
384
`HyperModel.fit()`, where we can access the dataset from the arguments. In the
385
following code, we tune whether to normalize the data before training the
386
model. This time we explicitly put `x` and `y` in the function signature
387
because we need to use them.
388
389
"""
390
391
392

class MyHyperModel(keras_tuner.HyperModel):
    def build(self, hp):
        model = keras.Sequential()
        model.add(layers.Flatten())
        model.add(
            layers.Dense(
                units=hp.Int("units", min_value=32, max_value=512, step=32),
                activation="relu",
            )
        )
        model.add(layers.Dense(10, activation="softmax"))
        model.compile(
            optimizer="adam",
            loss="categorical_crossentropy",
            metrics=["accuracy"],
        )
        return model

    def fit(self, hp, model, x, y, **kwargs):
        if hp.Boolean("normalize"):
            x = layers.Normalization()(x)
        return model.fit(
            x,
            y,
            # Tune whether to shuffle the data in each epoch.
            shuffle=hp.Boolean("shuffle"),
            **kwargs,
        )


hp = keras_tuner.HyperParameters()
hypermodel = MyHyperModel()
model = hypermodel.build(hp)
hypermodel.fit(hp, model, np.random.rand(100, 28, 28), np.random.rand(100, 10))

"""
428
If a hyperparameter is used both in `build()` and `fit()`, you can define it in
429
`build()` and use `hp.get(hp_name)` to retrieve it in `fit()`. We use the
430
image size as an example. It is both used as the input shape in `build()`, and
431
used by data prerprocessing step to crop the images in `fit()`.
432
"""
433
434
435

class MyHyperModel(keras_tuner.HyperModel):
    def build(self, hp):
        image_size = hp.Int("image_size", 10, 28)
        inputs = keras.Input(shape=(image_size, image_size))
        outputs = layers.Flatten()(inputs)
        outputs = layers.Dense(
            units=hp.Int("units", min_value=32, max_value=512, step=32),
            activation="relu",
        )(outputs)
        outputs = layers.Dense(10, activation="softmax")(outputs)
        model = keras.Model(inputs, outputs)
        model.compile(
            optimizer="adam",
            loss="categorical_crossentropy",
            metrics=["accuracy"],
        )
        return model

    def fit(self, hp, model, x, y, validation_data=None, **kwargs):
        if hp.Boolean("normalize"):
            x = layers.Normalization()(x)
        image_size = hp.get("image_size")
        cropped_x = x[:, :image_size, :image_size, :]
        if validation_data:
            x_val, y_val = validation_data
            cropped_x_val = x_val[:, :image_size, :image_size, :]
            validation_data = (cropped_x_val, y_val)
        return model.fit(
            cropped_x,
            y,
            # Tune whether to shuffle the data in each epoch.
            shuffle=hp.Boolean("shuffle"),
            validation_data=validation_data,
            **kwargs,
        )


tuner = keras_tuner.RandomSearch(
    MyHyperModel(),
    objective="val_accuracy",
    max_trials=3,
    overwrite=True,
    directory="my_dir",
    project_name="tune_hypermodel",
)

tuner.search(x_train, y_train, epochs=2, validation_data=(x_val, y_val))

"""
484
### Retrain the model
485
486
Using `HyperModel` also allows you to retrain the best model by yourself.
487
"""
488
489
hypermodel = MyHyperModel()
490
best_hp = tuner.get_best_hyperparameters()[0]
491
model = hypermodel.build(best_hp)
492
hypermodel.fit(best_hp, model, x_all, y_all, epochs=1)
493
494
"""
495
## Specify the tuning objective
496
497
In all previous examples, we all just used validation accuracy
498
(`"val_accuracy"`) as the tuning objective to select the best model. Actually,
499
you can use any metric as the objective. The most commonly used metric is
500
`"val_loss"`, which is the validation loss.
501
502
### Built-in metric as the objective
503
504
There are many other built-in metrics in Keras you can use as the objective.
505
Here is [a list of the built-in metrics](https://keras.io/api/metrics/).
506
507
To use a built-in metric as the objective, you need to follow these steps:
508
509
* Compile the model with the the built-in metric. For example, you want to use
510
`MeanAbsoluteError()`. You need to compile the model with
511
`metrics=[MeanAbsoluteError()]`. You may also use its name string instead:
512
`metrics=["mean_absolute_error"]`. The name string of the metric is always
513
the snake case of the class name.
514
515
* Identify the objective name string. The name string of the objective is
516
always in the format of `f"val_{metric_name_string}"`. For example, the
517
objective name string of mean squared error evaluated on the validation data
518
should be `"val_mean_absolute_error"`.
519
520
* Wrap it into `keras_tuner.Objective`. We usually need to wrap the objective
521
into a `keras_tuner.Objective` object to specify the direction to optimize the
522
objective. For example, we want to minimize the mean squared error, we can use
523
`keras_tuner.Objective("val_mean_absolute_error", "min")`. The direction should
524
be either `"min"` or `"max"`.
525
526
* Pass the wrapped objective to the tuner.
527
528
You can see the following barebone code example.
529
"""
530
531
532

def build_regressor(hp):
    model = keras.Sequential(
        [
            layers.Dense(units=hp.Int("units", 32, 128, 32), activation="relu"),
            layers.Dense(units=1),
        ]
    )
    model.compile(
        optimizer="adam",
        loss="mean_squared_error",
        # Objective is one of the metrics.
        metrics=[keras.metrics.MeanAbsoluteError()],
    )
    return model


tuner = keras_tuner.RandomSearch(
    hypermodel=build_regressor,
    # The objective name and direction.
    # Name is the f"val_{snake_case_metric_class_name}".
    objective=keras_tuner.Objective("val_mean_absolute_error", direction="min"),
    max_trials=3,
    overwrite=True,
    directory="my_dir",
    project_name="built_in_metrics",
)

tuner.search(
    x=np.random.rand(100, 10),
    y=np.random.rand(100, 1),
    validation_data=(np.random.rand(20, 10), np.random.rand(20, 1)),
)

tuner.results_summary()

"""
568
### Custom metric as the objective
569
570
You may implement your own metric and use it as the hyperparameter search
571
objective. Here, we use mean squared error (MSE) as an example. First, we
572
implement the MSE metric by subclassing `keras.metrics.Metric`. Remember to
573
give a name to your metric using the `name` argument of `super().__init__()`,
574
which will be used later. Note: MSE is actually a build-in metric, which can be
575
imported with `keras.metrics.MeanSquaredError`. This is just an example to show
576
how to use a custom metric as the hyperparameter search objective.
577
578
For more information about implementing custom metrics, please see [this
579
tutorial](https://keras.io/api/metrics/#creating-custom-metrics). If you would
580
like a metric with a different function signature than `update_state(y_true,
581
y_pred, sample_weight)`, you can override the `train_step()` method of your
582
model following [this
583
tutorial](https://keras.io/guides/customizing_what_happens_in_fit/#going-lowerlevel).
584
585
"""
586
587
from keras import ops


class CustomMetric(keras.metrics.Metric):
    def __init__(self, **kwargs):
        # Specify the name of the metric as "custom_metric".
        super().__init__(name="custom_metric", **kwargs)
        self.sum = self.add_weight(name="sum", initializer="zeros")
        self.count = self.add_weight(name="count", dtype="int32", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        values = ops.square(y_true - y_pred)
        count = ops.shape(y_true)[0]
        if sample_weight is not None:
            sample_weight = ops.cast(sample_weight, self.dtype)
            values *= sample_weight
            count *= sample_weight
        self.sum.assign_add(ops.sum(values))
        self.count.assign_add(count)

    def result(self):
        return self.sum / ops.cast(self.count, "float32")

    def reset_state(self):
        self.sum.assign(0)
        self.count.assign(0)

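"""
As a quick sanity check of the metric (a small illustration, not part of the
original guide):
"""

m = CustomMetric()
m.update_state(
    np.array([[0.0], [1.0]], dtype="float32"),
    np.array([[1.0], [1.0]], dtype="float32"),
)
print(m.result())  # Mean of the squared errors (1.0 and 0.0) -> 0.5
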
"""
616
Run the search with the custom objective.
617
"""
618
619
620

def build_regressor(hp):
    model = keras.Sequential(
        [
            layers.Dense(units=hp.Int("units", 32, 128, 32), activation="relu"),
            layers.Dense(units=1),
        ]
    )
    model.compile(
        optimizer="adam",
        loss="mean_squared_error",
        # Put the custom metric into the metrics.
        metrics=[CustomMetric()],
    )
    return model


tuner = keras_tuner.RandomSearch(
    hypermodel=build_regressor,
    # Specify the name and direction of the objective.
    objective=keras_tuner.Objective("val_custom_metric", direction="min"),
    max_trials=3,
    overwrite=True,
    directory="my_dir",
    project_name="custom_metrics",
)

tuner.search(
    x=np.random.rand(100, 10),
    y=np.random.rand(100, 1),
    validation_data=(np.random.rand(20, 10), np.random.rand(20, 1)),
)

tuner.results_summary()

"""
655
If your custom objective is hard to put into a custom metric, you can also
656
evaluate the model by yourself in `HyperModel.fit()` and return the objective
657
value. The objective value would be minimized by default. In this case, you
658
don't need to specify the `objective` when initializing the tuner. However, in
659
this case, the metric value will not be tracked in the Keras logs by only
660
KerasTuner logs. Therefore, these values would not be displayed by any
661
TensorBoard view using the Keras metrics.
662
"""
663
664
665

class HyperRegressor(keras_tuner.HyperModel):
    def build(self, hp):
        model = keras.Sequential(
            [
                layers.Dense(units=hp.Int("units", 32, 128, 32), activation="relu"),
                layers.Dense(units=1),
            ]
        )
        model.compile(
            optimizer="adam",
            loss="mean_squared_error",
        )
        return model

    def fit(self, hp, model, x, y, validation_data, **kwargs):
        model.fit(x, y, **kwargs)
        x_val, y_val = validation_data
        y_pred = model.predict(x_val)
        # Return a single float to minimize.
        return np.mean(np.abs(y_pred - y_val))


tuner = keras_tuner.RandomSearch(
    hypermodel=HyperRegressor(),
    # No objective to specify.
    # Objective is the return value of `HyperModel.fit()`.
    max_trials=3,
    overwrite=True,
    directory="my_dir",
    project_name="custom_eval",
)
tuner.search(
    x=np.random.rand(100, 10),
    y=np.random.rand(100, 1),
    validation_data=(np.random.rand(20, 10), np.random.rand(20, 1)),
)

tuner.results_summary()

"""
705
If you have multiple metrics to track in KerasTuner, but only use one of them
706
as the objective, you can return a dictionary, whose keys are the metric names
707
and the values are the metrics values, for example, return `{"metric_a": 1.0,
708
"metric_b", 2.0}`. Use one of the keys as the objective name, for example,
709
`keras_tuner.Objective("metric_a", "min")`.
710
"""
711
712
713

class HyperRegressor(keras_tuner.HyperModel):
    def build(self, hp):
        model = keras.Sequential(
            [
                layers.Dense(units=hp.Int("units", 32, 128, 32), activation="relu"),
                layers.Dense(units=1),
            ]
        )
        model.compile(
            optimizer="adam",
            loss="mean_squared_error",
        )
        return model

    def fit(self, hp, model, x, y, validation_data, **kwargs):
        model.fit(x, y, **kwargs)
        x_val, y_val = validation_data
        y_pred = model.predict(x_val)
        # Return a dictionary of metrics for KerasTuner to track.
        return {
            "metric_a": -np.mean(np.abs(y_pred - y_val)),
            "metric_b": np.mean(np.square(y_pred - y_val)),
        }


tuner = keras_tuner.RandomSearch(
    hypermodel=HyperRegressor(),
    # Objective is one of the keys.
    # Maximize the negative MAE, equivalent to minimizing MAE.
    objective=keras_tuner.Objective("metric_a", "max"),
    max_trials=3,
    overwrite=True,
    directory="my_dir",
    project_name="custom_eval_dict",
)
tuner.search(
    x=np.random.rand(100, 10),
    y=np.random.rand(100, 1),
    validation_data=(np.random.rand(20, 10), np.random.rand(20, 1)),
)

tuner.results_summary()

"""
757
## Tune end-to-end workflows
758
759
In some cases, it is hard to align your code into build and fit functions. You
760
can also keep your end-to-end workflow in one place by overriding
761
`Tuner.run_trial()`, which gives you full control of a trial. You can see it
762
as a black-box optimizer for anything.
763
764
### Tune any function
765
766
For example, you can find a value of `x`, which minimizes `f(x)=x*x+1`. In the
767
following code, we just define `x` as a hyperparameter, and return `f(x)` as
768
the objective value. The `hypermodel` and `objective` argument for initializing
769
the tuner can be omitted.
770
"""
771
772
773

class MyTuner(keras_tuner.RandomSearch):
    def run_trial(self, trial, *args, **kwargs):
        # Get the hp from trial.
        hp = trial.hyperparameters
        # Define "x" as a hyperparameter.
        x = hp.Float("x", min_value=-1.0, max_value=1.0)
        # Return the objective value to minimize.
        return x * x + 1


tuner = MyTuner(
    # No hypermodel or objective specified.
    max_trials=20,
    overwrite=True,
    directory="my_dir",
    project_name="tune_anything",
)

# No need to pass anything to search()
# unless you use them in run_trial().
tuner.search()
print(tuner.get_best_hyperparameters()[0].get("x"))

"""
797
### Keep Keras code separate
798
799
You can keep all your Keras code unchanged and use KerasTuner to tune it. It
800
is useful if you cannot modify the Keras code for some reason.
801
802
It also gives you more flexibility. You don't have to separate the model
803
building and training code apart. However, this workflow would not help you
804
save the model or connect with the TensorBoard plugins.
805
806
To save the model, you can use `trial.trial_id`, which is a string to uniquely
807
identify a trial, to construct different paths to save the models from
808
different trials.
809
"""
810
811
import os


def keras_code(units, optimizer, saving_path):
    # Build model
    model = keras.Sequential(
        [
            layers.Dense(units=units, activation="relu"),
            layers.Dense(units=1),
        ]
    )
    model.compile(
        optimizer=optimizer,
        loss="mean_squared_error",
    )

    # Prepare data
    x_train = np.random.rand(100, 10)
    y_train = np.random.rand(100, 1)
    x_val = np.random.rand(20, 10)
    y_val = np.random.rand(20, 1)

    # Train & eval model
    model.fit(x_train, y_train)

    # Save model
    model.save(saving_path)

    # Return a single float as the objective value.
    # You may also return a dictionary
    # of {metric_name: metric_value}.
    y_pred = model.predict(x_val)
    return np.mean(np.abs(y_pred - y_val))


class MyTuner(keras_tuner.RandomSearch):
    def run_trial(self, trial, **kwargs):
        hp = trial.hyperparameters
        return keras_code(
            units=hp.Int("units", 32, 128, 32),
            optimizer=hp.Choice("optimizer", ["adam", "adadelta"]),
            saving_path=os.path.join("/tmp", f"{trial.trial_id}.keras"),
        )


tuner = MyTuner(
    max_trials=3,
    overwrite=True,
    directory="my_dir",
    project_name="keep_code_separate",
)
tuner.search()
# Retrain the model.
best_hp = tuner.get_best_hyperparameters()[0]
keras_code(**best_hp.values, saving_path="/tmp/best_model.keras")

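"""
To reload the model saved by the best trial, you can look up its `trial_id`
(a sketch, not part of the original guide, assuming the saving paths used
above):
"""

best_trial = tuner.oracle.get_best_trials(1)[0]
best_model = keras.models.load_model(f"/tmp/{best_trial.trial_id}.keras")
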
"""
868
## KerasTuner includes pre-made tunable applications: HyperResNet and HyperXception
869
870
These are ready-to-use hypermodels for computer vision.
871
872
They come pre-compiled with `loss="categorical_crossentropy"` and
873
`metrics=["accuracy"]`.
874
875
"""
876
877
from keras_tuner.applications import HyperResNet

hypermodel = HyperResNet(input_shape=(28, 28, 1), classes=10)

tuner = keras_tuner.RandomSearch(
    hypermodel,
    objective="val_accuracy",
    max_trials=2,
    overwrite=True,
    directory="my_dir",
    project_name="built_in_hypermodel",
)
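
"""
To run the search with the MNIST data prepared earlier, call `search()` as
before. A sketch (not part of the original guide; note that this trains full
ResNet models, so it is slow without a GPU):

```python
tuner.search(x_train, y_train, epochs=1, validation_data=(x_val, y_val))
```
"""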