Keras Hyperparameter Tuning
We'll use the MNIST dataset. The downloaded data is split into three parts: 55,000 data points of training data (mnist.train), 10,000 points of test data (mnist.test), and 5,000 points of validation data (mnist.validation).
Each part of the dataset contains both the data and the labels, which we can access via .images and .labels. For example, the training images are mnist.train.images and the training labels are mnist.train.labels (one-hot encoded).
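To make "one-hot encoded" concrete, here is a minimal sketch in plain Python. The helper names one_hot and decode are hypothetical and are not part of the MNIST loader. Each digit label becomes a length-10 vector with a single 1.0 at the index of the digit.

```python
def one_hot(label, num_classes=10):
    """Return a one-hot list for an integer class label (hypothetical helper)."""
    if not 0 <= label < num_classes:
        raise ValueError("label must be in [0, num_classes)")
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def decode(vector):
    """Recover the integer label: the index of the largest entry."""
    return max(range(len(vector)), key=vector.__getitem__)


labels = [5, 0, 4]                       # three example digit labels
encoded = [one_hot(y) for y in labels]   # one length-10 vector per label
decoded = [decode(v) for v in encoded]   # back to the integer labels
```

Decoding with the index of the largest entry is the same operation that turns a softmax output into a predicted digit.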
Keras provides a wrapper class, KerasClassifier, that allows us to use our deep learning models with scikit-learn. This is especially useful when we want to tune hyperparameters using scikit-learn's RandomizedSearchCV or GridSearchCV.
To use it, we first define a function that takes the arguments we wish to tune. Inside the function, we define the network's structure as usual and compile it. The function is then passed to KerasClassifier's build_fn parameter. Note that, as with all other estimators in scikit-learn, build_fn should provide default values for its arguments, so that we can create the estimator without passing in a value for every parameter.
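A minimal sketch of such a function might look like the following. The layer sizes, argument names (hidden_units, dropout_rate, learning_rate), and candidate values are illustrative assumptions, not the notebook's actual settings. The wrapper import path assumes an older TensorFlow release that still ships tensorflow.keras.wrappers.scikit_learn; newer releases appear to have dropped it in favor of the separate SciKeras package.

```python
def build_model(hidden_units=128, dropout_rate=0.2, learning_rate=0.001):
    """Build and compile a small MNIST classifier.

    Every argument has a default, so the wrapper can call
    build_model() without receiving any values.
    """
    # Imported lazily so that defining the function does not require Keras.
    from tensorflow import keras

    model = keras.Sequential([
        keras.Input(shape=(784,)),                    # flattened 28x28 images
        keras.layers.Dense(hidden_units, activation="relu"),
        keras.layers.Dropout(dropout_rate),
        keras.layers.Dense(10, activation="softmax"), # one output per digit
    ])
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",              # labels are one-hot encoded
        metrics=["accuracy"],
    )
    return model


def make_estimator(**sk_params):
    """Wrap build_model as a scikit-learn estimator (hypothetical helper)."""
    # Assumption: this import path exists only in older TensorFlow releases.
    from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
    return KerasClassifier(build_fn=build_model, **sk_params)


# Illustrative candidate values; the keys must match build_model's argument names.
param_distributions = {
    "hidden_units": [64, 128, 256],
    "dropout_rate": [0.1, 0.2, 0.5],
    "learning_rate": [1e-2, 1e-3, 1e-4],
}
```

The estimator returned by make_estimator() and the param_distributions dictionary can then be handed to RandomizedSearchCV, which sets the sampled values on the estimator so that they reach build_model as arguments.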