Path: blob/master/guides/making_new_layers_and_models_via_subclassing.py
"""1Title: Making new layers and models via subclassing2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2019/03/014Last modified: 2023/06/255Description: Complete guide to writing `Layer` and `Model` objects from scratch.6Accelerator: None7"""89"""10## Introduction1112This guide will cover everything you need to know to build your own13subclassed layers and models. In particular, you'll learn about the following features:1415- The `Layer` class16- The `add_weight()` method17- Trainable and non-trainable weights18- The `build()` method19- Making sure your layers can be used with any backend20- The `add_loss()` method21- The `training` argument in `call()`22- The `mask` argument in `call()`23- Making sure your layers can be serialized2425Let's dive in.26"""27"""28## Setup29"""3031import numpy as np32import keras33from keras import ops34from keras import layers3536"""37## The `Layer` class: the combination of state (weights) and some computation3839One of the central abstractions in Keras is the `Layer` class. A layer40encapsulates both a state (the layer's "weights") and a transformation from41inputs to outputs (a "call", the layer's forward pass).4243Here's a densely-connected layer. It has two state variables:44the variables `w` and `b`.45"""464748class Linear(keras.layers.Layer):49def __init__(self, units=32, input_dim=32):50super().__init__()51self.w = self.add_weight(52shape=(input_dim, units),53initializer="random_normal",54trainable=True,55)56self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)5758def call(self, inputs):59return ops.matmul(inputs, self.w) + self.b606162"""63You would use a layer by calling it on some tensor input(s), much like a Python64function.65"""6667x = ops.ones((2, 2))68linear_layer = Linear(4, 2)69y = linear_layer(x)70print(y)7172"""73Note that the weights `w` and `b` are automatically tracked by the layer upon74being set as layer attributes:75"""7677assert linear_layer.weights == [linear_layer.w, linear_layer.b]7879"""80## Layers can have non-trainable weights8182Besides trainable weights, you can add non-trainable weights to a layer as83well. 
"""
## Layers can have non-trainable weights

Besides trainable weights, you can add non-trainable weights to a layer as
well. Such weights are meant not to be taken into account during
backpropagation, when you are training the layer.

Here's how to add and use a non-trainable weight:
"""


class ComputeSum(keras.layers.Layer):
    def __init__(self, input_dim):
        super().__init__()
        self.total = self.add_weight(
            initializer="zeros", shape=(input_dim,), trainable=False
        )

    def call(self, inputs):
        self.total.assign_add(ops.sum(inputs, axis=0))
        return self.total


x = ops.ones((2, 2))
my_sum = ComputeSum(2)
y = my_sum(x)
print(y.numpy())
y = my_sum(x)
print(y.numpy())

"""
It's part of `layer.weights`, but it gets categorized as a non-trainable weight:
"""

print("weights:", len(my_sum.weights))
print("non-trainable weights:", len(my_sum.non_trainable_weights))

# It's not included in the trainable weights:
print("trainable_weights:", my_sum.trainable_weights)

"""
## Best practice: deferring weight creation until the shape of the inputs is known

Our `Linear` layer above took an `input_dim` argument that was used to compute
the shape of the weights `w` and `b` in `__init__()`:
"""


class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super().__init__()
        self.w = self.add_weight(
            shape=(input_dim, units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b


"""
In many cases, you may not know in advance the size of your inputs, and you
would like to lazily create weights when that value becomes known, some time
after instantiating the layer.

In the Keras API, we recommend creating layer weights in the
`build(self, input_shape)` method of your layer. Like this:
"""


class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b


"""
The `__call__()` method of your layer will automatically run `build()` the first
time it is called. You now have a layer that's lazy and thus easier to use:
"""

# At instantiation, we don't know on what inputs this is going to get called
linear_layer = Linear(32)

# The layer's weights are created dynamically the first time the layer is called
y = linear_layer(x)

"""
Implementing `build()` separately as shown above nicely separates creating weights
only once from using weights in every call.
"""
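"""
To see the deferred creation in action, here's a minimal sketch (assuming the
lazy `Linear` layer defined just above): the weight list stays empty until the
first call triggers `build()`.

```python
lazy_layer = Linear(32)
print(len(lazy_layer.weights))  # 0 -- nothing has been built yet

_ = lazy_layer(ops.ones((2, 4)))  # First call: `build()` runs with input_shape (2, 4)
print(len(lazy_layer.weights))  # 2 -- `w` and `b` now exist
```
"""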
"""
## Layers are recursively composable

If you assign a Layer instance as an attribute of another Layer, the outer layer
will start tracking the weights created by the inner layer.

We recommend creating such sublayers in the `__init__()` method and leaving it to
the first `__call__()` to trigger building their weights.
"""


class MLPBlock(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.linear_1 = Linear(32)
        self.linear_2 = Linear(32)
        self.linear_3 = Linear(1)

    def call(self, inputs):
        x = self.linear_1(inputs)
        x = keras.activations.relu(x)
        x = self.linear_2(x)
        x = keras.activations.relu(x)
        return self.linear_3(x)


mlp = MLPBlock()
y = mlp(ops.ones(shape=(3, 64)))  # The first call to the `mlp` will create the weights
print("weights:", len(mlp.weights))
print("trainable weights:", len(mlp.trainable_weights))
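
"""
The tracking is fully recursive: if a block contains blocks, the outermost layer
still sees every weight. Here's a minimal, illustrative sketch (the `DoubleMLP`
name is just for this example), reusing `MLPBlock` from above:

```python
class DoubleMLP(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.block_1 = MLPBlock()
        self.block_2 = MLPBlock()

    def call(self, inputs):
        return self.block_2(self.block_1(inputs))


double = DoubleMLP()
_ = double(ops.ones(shape=(3, 64)))
print("weights:", len(double.weights))  # 12: each MLPBlock contributes 3 x (w, b)
```
"""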
"""
## Backend-agnostic layers and backend-specific layers

As long as a layer only uses APIs from the `keras.ops` namespace
(or other Keras namespaces such as `keras.activations`, `keras.random`, or `keras.layers`),
then it can be used with any backend -- TensorFlow, JAX, or PyTorch.

All layers you've seen so far in this guide work with all Keras backends.

The `keras.ops` namespace gives you access to:

- The NumPy API, e.g. `ops.matmul`, `ops.sum`, `ops.reshape`, `ops.stack`, etc.
- Neural networks-specific APIs such as `ops.softmax`, `ops.conv`, `ops.binary_crossentropy`, `ops.relu`, etc.

You can also use backend-native APIs in your layers (such as `tf.nn` functions),
but if you do this, then your layer will only be usable with the backend in question.
For instance, you could write the following JAX-specific layer using `jax.numpy`:

```python
import jax

class Linear(keras.layers.Layer):
    ...

    def call(self, inputs):
        return jax.numpy.matmul(inputs, self.w) + self.b
```

This would be the equivalent TensorFlow-specific layer:

```python
import tensorflow as tf

class Linear(keras.layers.Layer):
    ...

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b
```

And this would be the equivalent PyTorch-specific layer:

```python
import torch

class Linear(keras.layers.Layer):
    ...

    def call(self, inputs):
        return torch.matmul(inputs, self.w) + self.b
```

Because cross-backend compatibility is a tremendously useful property, we strongly
recommend that you seek to always make your layers backend-agnostic by leveraging
only Keras APIs.
"""

"""
## The `add_loss()` method

When writing the `call()` method of a layer, you can create loss tensors that
you will want to use later, when writing your training loop. This is doable by
calling `self.add_loss(value)`:
"""


# A layer that creates an activity regularization loss
class ActivityRegularizationLayer(keras.layers.Layer):
    def __init__(self, rate=1e-2):
        super().__init__()
        self.rate = rate

    def call(self, inputs):
        self.add_loss(self.rate * ops.mean(inputs))
        return inputs


"""
These losses (including those created by any inner layer) can be retrieved via
`layer.losses`. This property is reset at the start of every `__call__()` to
the top-level layer, so that `layer.losses` always contains the loss values
created during the last forward pass.
"""


class OuterLayer(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.activity_reg = ActivityRegularizationLayer(1e-2)

    def call(self, inputs):
        return self.activity_reg(inputs)


layer = OuterLayer()
assert len(layer.losses) == 0  # No losses yet since the layer has never been called

_ = layer(ops.zeros((1, 1)))
assert len(layer.losses) == 1  # We created one loss value

# `layer.losses` gets reset at the start of each __call__
_ = layer(ops.zeros((1, 1)))
assert len(layer.losses) == 1  # This is the loss created during the call above

"""
In addition, the `losses` property also contains regularization losses created
for the weights of any inner layer:
"""


class OuterLayerWithKernelRegularizer(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(
            32, kernel_regularizer=keras.regularizers.l2(1e-3)
        )

    def call(self, inputs):
        return self.dense(inputs)


layer = OuterLayerWithKernelRegularizer()
_ = layer(ops.zeros((1, 1)))

# This is `1e-3 * sum(layer.dense.kernel ** 2)`,
# created by the `kernel_regularizer` above.
print(layer.losses)

"""
These losses are meant to be taken into account when writing custom training loops,
as sketched below.

They also work seamlessly with `fit()` (they get automatically summed and added to the main loss, if any):
"""

inputs = keras.Input(shape=(3,))
outputs = ActivityRegularizationLayer()(inputs)
model = keras.Model(inputs, outputs)

# If there is a loss passed in `compile`, the regularization
# losses get added to it
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# It's also possible not to pass any loss in `compile`,
# since the model already has a loss to minimize, via the `add_loss`
# call during the forward pass!
model.compile(optimizer="adam")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
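
"""
For reference, here's a minimal sketch of how these losses would typically be
consumed in a custom training loop. It assumes the TensorFlow backend, and a
`dataset`, `loss_fn`, and `optimizer` defined elsewhere:

```python
import tensorflow as tf

for x_batch_train, y_batch_train in dataset:
    with tf.GradientTape() as tape:
        y_pred = model(x_batch_train)  # Forward pass (re-populates `model.losses`)
        loss_value = loss_fn(y_batch_train, y_pred)
        # Add any losses created during the forward pass (e.g. regularization).
        loss_value += sum(model.losses)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
```
"""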
"""
## You can optionally enable serialization on your layers

If you need your custom layers to be serializable as part of a
[Functional model](/guides/functional_api/),
you can optionally implement a `get_config()` method:
"""


class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b

    def get_config(self):
        return {"units": self.units}


# Now you can recreate the layer from its config:
layer = Linear(64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)

"""
Note that the `__init__()` method of the base `Layer` class takes some keyword
arguments, in particular a `name` and a `dtype`. It's good practice to pass
these arguments to the parent class in `__init__()` and to include them in the
layer config:
"""


class Linear(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config


layer = Linear(64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)

"""
If you need more flexibility when deserializing the layer from its config, you
can also override the `from_config()` class method. This is the base
implementation of `from_config()`:

```python
def from_config(cls, config):
    return cls(**config)
```

To learn more about serialization and saving, see the complete
[guide to saving and serializing models](/guides/serialization_and_saving/).
"""

"""
## Privileged `training` argument in the `call()` method

Some layers, in particular the `BatchNormalization` layer and the `Dropout`
layer, have different behaviors during training and inference. For such
layers, it is standard practice to expose a `training` (boolean) argument in
the `call()` method.

By exposing this argument in `call()`, you enable the built-in training and
evaluation loops (e.g. `fit()`) to correctly use the layer in training and
inference.
"""


class CustomDropout(keras.layers.Layer):
    def __init__(self, rate, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs, training=None):
        if training:
            return keras.random.dropout(
                inputs, rate=self.rate, seed=self.seed_generator
            )
        return inputs
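
"""
As a quick sketch of the behavior (reusing `CustomDropout` from above): with
`training=True` the layer randomly drops entries, while with `training=False`
it is the identity. The built-in loops handle this flag for you: `fit()` calls
the layer with `training=True`, while `evaluate()` and `predict()` use
`training=False`.

```python
dropout = CustomDropout(rate=0.5)
x = ops.ones((2, 4))
print(dropout(x, training=True))   # Some entries are zeroed (the rest are rescaled)
print(dropout(x, training=False))  # Identical to `x`
```
"""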
"""
## Privileged `mask` argument in the `call()` method

The other privileged argument supported by `call()` is the `mask` argument.

You will find it in all Keras RNN layers. A mask is a boolean tensor (one
boolean value per timestep in the input) used to skip certain input timesteps
when processing timeseries data.

Keras will automatically pass the correct `mask` argument to `__call__()` for
layers that support it, when a mask is generated by a prior layer.
Mask-generating layers are the `Embedding`
layer configured with `mask_zero=True`, and the `Masking` layer.
"""
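"""
To make this concrete, here's a minimal, illustrative sketch of a layer that
consumes a `mask` (the `MaskedMeanPooling` name and implementation are just for
this example): it averages over timesteps while skipping masked positions.

```python
class MaskedMeanPooling(keras.layers.Layer):
    def call(self, inputs, mask=None):
        # inputs: (batch, timesteps, features); mask: (batch, timesteps) booleans
        if mask is None:
            return ops.mean(inputs, axis=1)
        mask = ops.cast(ops.expand_dims(mask, axis=-1), inputs.dtype)
        total = ops.sum(inputs * mask, axis=1)
        count = ops.maximum(ops.sum(mask, axis=1), 1.0)
        return total / count
```
"""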
"""
## The `Model` class

In general, you will use the `Layer` class to define inner computation blocks,
and will use the `Model` class to define the outer model -- the object you
will train.

For instance, in a ResNet50 model, you would have several ResNet blocks
subclassing `Layer`, and a single `Model` encompassing the entire ResNet50
network.

The `Model` class has the same API as `Layer`, with the following differences:

- It exposes built-in training, evaluation, and prediction loops
(`model.fit()`, `model.evaluate()`, `model.predict()`).
- It exposes the list of its inner layers, via the `model.layers` property.
- It exposes saving and serialization APIs (`save()`, `save_weights()`...)

Effectively, the `Layer` class corresponds to what we refer to in the
literature as a "layer" (as in "convolution layer" or "recurrent layer") or as
a "block" (as in "ResNet block" or "Inception block").

Meanwhile, the `Model` class corresponds to what is referred to in the
literature as a "model" (as in "deep learning model") or as a "network" (as in
"deep neural network").

So if you're wondering, "should I use the `Layer` class or the `Model` class?",
ask yourself: will I need to call `fit()` on it? Will I need to call `save()`
on it? If so, go with `Model`. If not (either because your class is just a block
in a bigger system, or because you are writing training & saving code yourself),
use `Layer`.

For instance, we could take a mini ResNet built from such `Layer`-subclassing
blocks, and wrap it in a `Model` that we could train with `fit()`, and that we
could save with `save_weights()`:
"""

"""
```python
class ResNet(keras.Model):

    def __init__(self, num_classes=1000):
        super().__init__()
        self.block_1 = ResNetBlock()
        self.block_2 = ResNetBlock()
        self.global_pool = layers.GlobalAveragePooling2D()
        self.classifier = layers.Dense(num_classes)

    def call(self, inputs):
        x = self.block_1(inputs)
        x = self.block_2(x)
        x = self.global_pool(x)
        return self.classifier(x)


resnet = ResNet()
dataset = ...
resnet.fit(dataset, epochs=10)
resnet.save("filepath.keras")
```
"""

"""
## Putting it all together: an end-to-end example

Here's what you've learned so far:

- A `Layer` encapsulates a state (created in `__init__()` or `build()`) and some
computation (defined in `call()`).
- Layers can be recursively nested to create new, bigger computation blocks.
- Layers are backend-agnostic as long as they only use Keras APIs. You can use
backend-native APIs (such as `jax.numpy`, `torch.nn` or `tf.nn`), but then
your layer will only be usable with that specific backend.
- Layers can create and track losses (typically regularization losses)
via `add_loss()`.
- The outer container, the thing you want to train, is a `Model`. A `Model` is
just like a `Layer`, but with added training and serialization utilities.

Let's put all of these things together into an end-to-end example: we're going
to implement a Variational AutoEncoder (VAE) in a backend-agnostic fashion
-- so that it runs the same with TensorFlow, JAX, and PyTorch.
We'll train it on MNIST digits.

Our VAE will be a subclass of `Model`, built as a nested composition of layers
that subclass `Layer`. It will feature a regularization loss (KL divergence).
"""


class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = ops.shape(z_mean)[0]
        dim = ops.shape(z_mean)[1]
        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon


class Encoder(layers.Layer):
    """Maps MNIST digits to a triplet (z_mean, z_log_var, z)."""

    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z


class Decoder(layers.Layer):
    """Converts z, the encoded digit vector, back into a readable digit."""

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        x = self.dense_proj(inputs)
        return self.dense_output(x)


class VariationalAutoEncoder(keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training."""

    def __init__(
        self,
        original_dim,
        intermediate_dim=64,
        latent_dim=32,
        name="autoencoder",
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add KL divergence regularization loss.
        kl_loss = -0.5 * ops.mean(
            z_log_var - ops.square(z_mean) - ops.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        return reconstructed


"""
Let's train it on MNIST using the `fit()` API:
"""

(x_train, _), _ = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255

original_dim = 784
vae = VariationalAutoEncoder(784, 64, 32)

optimizer = keras.optimizers.Adam(learning_rate=1e-3)
vae.compile(optimizer, loss=keras.losses.MeanSquaredError())

vae.fit(x_train, x_train, epochs=2, batch_size=64)
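
"""
Once trained, the VAE can be used like any other `Model`. For instance, here's a
quick sketch that reconstructs a few training digits:

```python
reconstructed = vae.predict(x_train[:5])
print(reconstructed.shape)  # (5, 784): one 784-dimensional reconstruction per input
```
"""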