Path: blob/master/guides/intro_to_keras_for_engineers.py
3273 views
"""1Title: Introduction to Keras for engineers2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2023/07/104Last modified: 2023/07/105Description: First contact with Keras 3.6Accelerator: GPU7"""89"""10## Introduction1112Keras 3 is a deep learning framework13works with TensorFlow, JAX, and PyTorch interchangeably.14This notebook will walk you through key Keras 3 workflows.1516Let's start by installing Keras 3:17"""1819"""shell20pip install keras --upgrade --quiet21"""2223"""24## Setup2526We're going to be using the JAX backend here -- but you can27edit the string below to `"tensorflow"` or `"torch"` and hit28"Restart runtime", and the whole notebook will run just the same!29This entire guide is backend-agnostic.30"""3132import numpy as np33import os3435os.environ["KERAS_BACKEND"] = "jax"3637# Note that Keras should only be imported after the backend38# has been configured. The backend cannot be changed once the39# package is imported.40import keras4142"""43## A first example: A MNIST convnet4445Let's start with the Hello World of ML: training a convnet46to classify MNIST digits.4748Here's the data:49"""5051# Load the data and split it between train and test sets52(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()5354# Scale images to the [0, 1] range55x_train = x_train.astype("float32") / 25556x_test = x_test.astype("float32") / 25557# Make sure images have shape (28, 28, 1)58x_train = np.expand_dims(x_train, -1)59x_test = np.expand_dims(x_test, -1)60print("x_train shape:", x_train.shape)61print("y_train shape:", y_train.shape)62print(x_train.shape[0], "train samples")63print(x_test.shape[0], "test samples")6465"""66Here's our model.6768Different model-building options that Keras offers include:6970- [The Sequential API](https://keras.io/guides/sequential_model/) (what we use below)71- [The Functional API](https://keras.io/guides/functional_api/) (most typical)72- [Writing your own models yourself via subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) (for advanced use cases)73"""7475# Model parameters76num_classes = 1077input_shape = (28, 28, 1)7879model = keras.Sequential(80[81keras.layers.Input(shape=input_shape),82keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),83keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),84keras.layers.MaxPooling2D(pool_size=(2, 2)),85keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),86keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),87keras.layers.GlobalAveragePooling2D(),88keras.layers.Dropout(0.5),89keras.layers.Dense(num_classes, activation="softmax"),90]91)9293"""94Here's our model summary:95"""9697model.summary()9899"""100We use the `compile()` method to specify the optimizer, loss function,101and the metrics to monitor. Note that with the JAX and TensorFlow backends,102XLA compilation is turned on by default.103"""104105model.compile(106loss=keras.losses.SparseCategoricalCrossentropy(),107optimizer=keras.optimizers.Adam(learning_rate=1e-3),108metrics=[109keras.metrics.SparseCategoricalAccuracy(name="acc"),110],111)112113"""114Let's train and evaluate the model. We'll set aside a validation split of 15%115of the data during training to monitor generalization on unseen data.116"""117118batch_size = 128119epochs = 20120121callbacks = [122keras.callbacks.ModelCheckpoint(filepath="model_at_epoch_{epoch}.keras"),123keras.callbacks.EarlyStopping(monitor="val_loss", patience=2),124]125126model.fit(127x_train,128y_train,129batch_size=batch_size,130epochs=epochs,131validation_split=0.15,132callbacks=callbacks,133)134score = model.evaluate(x_test, y_test, verbose=0)135136"""137During training, we were saving a model at the end of each epoch. You138can also save the model in its latest state like this:139"""140141model.save("final_model.keras")142143"""144And reload it like this:145"""146147model = keras.saving.load_model("final_model.keras")148149"""150Next, you can query predictions of class probabilities with `predict()`:151"""152153predictions = model.predict(x_test)154155"""156That's it for the basics!157"""158159"""160## Writing cross-framework custom components161162Keras enables you to write custom Layers, Models, Metrics, Losses, and Optimizers163that work across TensorFlow, JAX, and PyTorch with the same codebase. Let's take a look164at custom layers first.165166The `keras.ops` namespace contains:167168- An implementation of the NumPy API, e.g. `keras.ops.stack` or `keras.ops.matmul`.169- A set of neural network specific ops that are absent from NumPy, such as `keras.ops.conv`170or `keras.ops.binary_crossentropy`.171172Let's make a custom `Dense` layer that works with all backends:173"""174175176class MyDense(keras.layers.Layer):177def __init__(self, units, activation=None, name=None):178super().__init__(name=name)179self.units = units180self.activation = keras.activations.get(activation)181182def build(self, input_shape):183input_dim = input_shape[-1]184self.w = self.add_weight(185shape=(input_dim, self.units),186initializer=keras.initializers.GlorotNormal(),187name="kernel",188trainable=True,189)190191self.b = self.add_weight(192shape=(self.units,),193initializer=keras.initializers.Zeros(),194name="bias",195trainable=True,196)197198def call(self, inputs):199# Use Keras ops to create backend-agnostic layers/metrics/etc.200x = keras.ops.matmul(inputs, self.w) + self.b201return self.activation(x)202203204"""205Next, let's make a custom `Dropout` layer that relies on the `keras.random`206namespace:207"""208209210class MyDropout(keras.layers.Layer):211def __init__(self, rate, name=None):212super().__init__(name=name)213self.rate = rate214# Use seed_generator for managing RNG state.215# It is a state element and its seed variable is216# tracked as part of `layer.variables`.217self.seed_generator = keras.random.SeedGenerator(1337)218219def call(self, inputs):220# Use `keras.random` for random ops.221return keras.random.dropout(inputs, self.rate, seed=self.seed_generator)222223224"""225Next, let's write a custom subclassed model that uses our two custom layers:226"""227228229class MyModel(keras.Model):230def __init__(self, num_classes):231super().__init__()232self.conv_base = keras.Sequential(233[234keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),235keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),236keras.layers.MaxPooling2D(pool_size=(2, 2)),237keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),238keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),239keras.layers.GlobalAveragePooling2D(),240]241)242self.dp = MyDropout(0.5)243self.dense = MyDense(num_classes, activation="softmax")244245def call(self, x):246x = self.conv_base(x)247x = self.dp(x)248return self.dense(x)249250251"""252Let's compile it and fit it:253"""254255model = MyModel(num_classes=10)256model.compile(257loss=keras.losses.SparseCategoricalCrossentropy(),258optimizer=keras.optimizers.Adam(learning_rate=1e-3),259metrics=[260keras.metrics.SparseCategoricalAccuracy(name="acc"),261],262)263264model.fit(265x_train,266y_train,267batch_size=batch_size,268epochs=1, # For speed269validation_split=0.15,270)271272"""273## Training models on arbitrary data sources274275All Keras models can be trained and evaluated on a wide variety of data sources,276independently of the backend you're using. This includes:277278- NumPy arrays279- Pandas dataframes280- TensorFlow `tf.data.Dataset` objects281- PyTorch `DataLoader` objects282- Keras `PyDataset` objects283284They all work whether you're using TensorFlow, JAX, or PyTorch as your Keras backend.285286Let's try it out with PyTorch `DataLoaders`:287"""288289import torch290291# Create a TensorDataset292train_torch_dataset = torch.utils.data.TensorDataset(293torch.from_numpy(x_train), torch.from_numpy(y_train)294)295val_torch_dataset = torch.utils.data.TensorDataset(296torch.from_numpy(x_test), torch.from_numpy(y_test)297)298299# Create a DataLoader300train_dataloader = torch.utils.data.DataLoader(301train_torch_dataset, batch_size=batch_size, shuffle=True302)303val_dataloader = torch.utils.data.DataLoader(304val_torch_dataset, batch_size=batch_size, shuffle=False305)306307model = MyModel(num_classes=10)308model.compile(309loss=keras.losses.SparseCategoricalCrossentropy(),310optimizer=keras.optimizers.Adam(learning_rate=1e-3),311metrics=[312keras.metrics.SparseCategoricalAccuracy(name="acc"),313],314)315model.fit(train_dataloader, epochs=1, validation_data=val_dataloader)316317318"""319Now let's try this out with `tf.data`:320"""321322import tensorflow as tf323324train_dataset = (325tf.data.Dataset.from_tensor_slices((x_train, y_train))326.batch(batch_size)327.prefetch(tf.data.AUTOTUNE)328)329test_dataset = (330tf.data.Dataset.from_tensor_slices((x_test, y_test))331.batch(batch_size)332.prefetch(tf.data.AUTOTUNE)333)334335model = MyModel(num_classes=10)336model.compile(337loss=keras.losses.SparseCategoricalCrossentropy(),338optimizer=keras.optimizers.Adam(learning_rate=1e-3),339metrics=[340keras.metrics.SparseCategoricalAccuracy(name="acc"),341],342)343model.fit(train_dataset, epochs=1, validation_data=test_dataset)344345"""346## Further reading347348This concludes our short overview of the new multi-backend capabilities349of Keras 3. Next, you can learn about:350351### How to customize what happens in `fit()`352353Want to implement a non-standard training algorithm yourself but still want to benefit from354the power and usability of `fit()`? It's easy to customize355`fit()` to support arbitrary use cases:356357- [Customizing what happens in `fit()` with TensorFlow](http://keras.io/guides/custom_train_step_in_tensorflow/)358- [Customizing what happens in `fit()` with JAX](http://keras.io/guides/custom_train_step_in_jax/)359- [Customizing what happens in `fit()` with PyTorch](http://keras.io/guides/custom_train_step_in_torch/)360361## How to write custom training loops362363- [Writing a training loop from scratch in TensorFlow](http://keras.io/guides/writing_a_custom_training_loop_in_tensorflow/)364- [Writing a training loop from scratch in JAX](http://keras.io/guides/writing_a_custom_training_loop_in_jax/)365- [Writing a training loop from scratch in PyTorch](http://keras.io/guides/writing_a_custom_training_loop_in_torch/)366367## How to distribute training368369- [Guide to distributed training with TensorFlow](http://keras.io/guides/distributed_training_with_tensorflow/)370- [JAX distributed training example](https://github.com/keras-team/keras/blob/master/examples/demo_jax_distributed.py)371- [PyTorch distributed training example](https://github.com/keras-team/keras/blob/master/examples/demo_torch_multi_gpu.py)372373Enjoy the library! 🚀374"""375376377