Introduction to Keras for engineers
Author: fchollet
Date created: 2023/07/10
Last modified: 2023/07/10
Description: First contact with Keras 3.
Introduction
Keras 3 is a deep learning framework that works with TensorFlow, JAX, and PyTorch interchangeably. This notebook will walk you through key Keras 3 workflows.
Let's start by installing Keras 3:
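A typical way to do this in a notebook (in a plain shell, drop the leading `!`):

```python
# Install or upgrade to Keras 3
!pip install keras --upgrade -q
```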
Setup
We're going to be using the JAX backend here -- but you can edit the string below to `"tensorflow"` or `"torch"` and hit "Restart runtime", and the whole notebook will run just the same! This entire guide is backend-agnostic.
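A minimal setup cell along these lines (the backend must be set before Keras is imported):

```python
import os

# Select the backend: "jax", "tensorflow", or "torch".
# This must happen before the first `import keras`.
os.environ["KERAS_BACKEND"] = "jax"

import keras
import numpy as np
```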
A first example: An MNIST convnet
Let's start with the Hello World of ML: training a convnet to classify MNIST digits.
Here's the data:
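A sketch of the standard MNIST preprocessing (scale pixels to the [0, 1] range and add a channel axis):

```python
# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

num_classes = 10
input_shape = (28, 28, 1)
```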
Here's our model.
Different model-building options that Keras offers include:

- The Sequential API (what we use below)
- The Functional API (most typical)
- Writing your own models via subclassing (for advanced use cases)
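Here's a minimal sketch along those lines, using the Sequential API (the exact stack of layers is illustrative, not the only reasonable choice):

```python
model = keras.Sequential(
    [
        keras.layers.Input(shape=input_shape),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(num_classes, activation="softmax"),
    ]
)
```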
Here's our model summary:
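You can print it with:

```python
model.summary()
```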
We use the `compile()` method to specify the optimizer, loss function, and the metrics to monitor. Note that with the JAX and TensorFlow backends, XLA compilation is turned on by default.
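A representative `compile()` call (the particular optimizer, loss, and metric are a choice, not a requirement; sparse categorical crossentropy matches the integer labels loaded above):

```python
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)
```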
Let's train and evaluate the model. We'll set aside a validation split of 15% of the data during training to monitor generalization on unseen data.
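A sketch of such a training run. The `ModelCheckpoint` callback is what saves a model at the end of each epoch, as mentioned below (the filename pattern and epoch count are illustrative):

```python
batch_size = 128
epochs = 10

callbacks = [
    # Save a checkpoint file at the end of each epoch
    keras.callbacks.ModelCheckpoint(filepath="model_at_epoch_{epoch}.keras"),
]

model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.15,  # hold out 15% of the training data
    callbacks=callbacks,
)
score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])
```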
During training, we were saving a model at the end of each epoch. You can also save the model in its latest state like this:
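For example (the filename is arbitrary; the `.keras` extension selects the native Keras format):

```python
model.save("final_model.keras")
```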
And reload it like this:
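```python
model = keras.saving.load_model("final_model.keras")
```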
Next, you can query predictions of class probabilities with `predict()`:
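```python
predictions = model.predict(x_test)
print(predictions.shape)  # (10000, 10): one probability per class for each sample
```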
That's it for the basics!
Writing cross-framework custom components
Keras enables you to write custom Layers, Models, Metrics, Losses, and Optimizers that work across TensorFlow, JAX, and PyTorch with the same codebase. Let's take a look at custom layers first.
The `keras.ops` namespace contains:

- An implementation of the NumPy API, e.g. `keras.ops.stack` or `keras.ops.matmul`.
- A set of neural network-specific ops that are absent from NumPy, such as `keras.ops.conv` or `keras.ops.binary_crossentropy`.
Let's make a custom `Dense` layer that works with all backends:
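Here's a sketch of what that can look like. All compute goes through `keras.ops`, so the same layer runs on every backend (the class name `MyDense` and the initializer choices are ours, for illustration):

```python
class MyDense(keras.layers.Layer):
    def __init__(self, units, activation=None, name=None):
        super().__init__(name=name)
        self.units = units
        self.activation = keras.activations.get(activation)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.w = self.add_weight(
            shape=(input_dim, self.units),
            initializer=keras.initializers.GlorotNormal(),
            trainable=True,
            name="kernel",
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="bias",
        )

    def call(self, inputs):
        # Only keras.ops is used here, so this runs on TF, JAX, and PyTorch alike.
        x = keras.ops.matmul(inputs, self.w) + self.b
        return self.activation(x)
```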
Next, let's make a custom `Dropout` layer that relies on the `keras.random` namespace:
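A sketch using `keras.random.dropout` together with a `keras.random.SeedGenerator`, so the RNG state is tracked as part of the layer's variables (the seed value is arbitrary):

```python
class MyDropout(keras.layers.Layer):
    def __init__(self, rate, name=None):
        super().__init__(name=name)
        self.rate = rate
        # A SeedGenerator holds backend-agnostic RNG state,
        # tracked as part of the layer's variables.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        # keras.random.dropout behaves the same with every backend.
        return keras.random.dropout(inputs, rate=self.rate, seed=self.seed_generator)
```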
Next, let's write a custom subclassed model that uses our two custom layers:
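A sketch, assuming the `MyDense` and `MyDropout` classes from above are in scope (the convnet trunk is illustrative):

```python
class MyModel(keras.Model):
    def __init__(self, num_classes):
        super().__init__()
        # A small convolutional trunk, reusing built-in layers
        self.conv_base = keras.Sequential(
            [
                keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
                keras.layers.MaxPooling2D(pool_size=(2, 2)),
                keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
                keras.layers.GlobalAveragePooling2D(),
            ]
        )
        # Our two custom, backend-agnostic layers
        self.dp = MyDropout(0.5)
        self.dense = MyDense(num_classes, activation="softmax")

    def call(self, x):
        x = self.conv_base(x)
        x = self.dp(x)
        return self.dense(x)
```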
Let's compile it and fit it:
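For instance (hyperparameters are illustrative):

```python
model = MyModel(num_classes=10)
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)
model.fit(x_train, y_train, batch_size=128, epochs=1, validation_split=0.15)
```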
Training models on arbitrary data sources
All Keras models can be trained and evaluated on a wide variety of data sources, independently of the backend you're using. This includes:
- NumPy arrays
- Pandas dataframes
- TensorFlow `tf.data.Dataset` objects
- PyTorch `DataLoader` objects
- Keras `PyDataset` objects
They all work whether you're using TensorFlow, JAX, or PyTorch as your Keras backend.
Let's try it out with PyTorch `DataLoader` objects:
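A sketch, assuming `torch` is installed and the MNIST arrays from earlier are still in scope:

```python
import torch

# Wrap the NumPy arrays in TensorDatasets
train_torch_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_train), torch.from_numpy(y_train)
)
test_torch_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_test), torch.from_numpy(y_test)
)

# Create DataLoaders from the datasets
train_dataloader = torch.utils.data.DataLoader(
    train_torch_dataset, batch_size=128, shuffle=True
)
test_dataloader = torch.utils.data.DataLoader(
    test_torch_dataset, batch_size=128, shuffle=False
)

# Keras models accept DataLoaders directly, regardless of the backend
model.fit(train_dataloader, epochs=1, validation_data=test_dataloader)
```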
Now let's try this out with `tf.data`:
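A sketch of the same flow with `tf.data` (again assuming the MNIST arrays are in scope):

```python
import tensorflow as tf

# Build batched, prefetched datasets from the NumPy arrays
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)

model.fit(train_dataset, epochs=1, validation_data=test_dataset)
```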
Further reading
This concludes our short overview of the new multi-backend capabilities of Keras 3. Next, you can learn about:
- How to customize what happens in `fit()`: want to implement a non-standard training algorithm yourself but still want to benefit from the power and usability of `fit()`? It's easy to customize `fit()` to support arbitrary use cases.
- How to write custom training loops
- How to distribute training
Enjoy the library! 🚀