Path: blob/master/site/en-snapshot/guide/keras/custom_layers_and_models.ipynb
Copyright 2020 The TensorFlow Authors.
Making new Layers and Models via subclassing
Setup
The Layer class: the combination of state (weights) and some computation
One of the central abstractions in Keras is the Layer class. A layer encapsulates both a state (the layer's "weights") and a transformation from inputs to outputs (a "call", the layer's forward pass).
Here's a densely-connected layer. It has a state: the variables w and b.
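A minimal sketch of such a layer could look like the following; the class name Linear and the units/input_dim arguments are illustrative choices, with the variables created eagerly in __init__():

```python
import tensorflow as tf
from tensorflow import keras


class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super().__init__()
        # Create the weight matrix `w` and the bias vector `b` as tf.Variables.
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(
            initial_value=w_init(shape=(input_dim, units), dtype="float32"),
            trainable=True,
        )
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(
            initial_value=b_init(shape=(units,), dtype="float32"),
            trainable=True,
        )

    def call(self, inputs):
        # Forward pass: an affine transformation of the inputs.
        return tf.matmul(inputs, self.w) + self.b
```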
You would use a layer by calling it on some tensor input(s), much like a Python function.
Note that the weights w and b are automatically tracked by the layer upon being set as layer attributes:
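For example, continuing with the Linear sketch above (a sketch, not necessarily the exact cell that originally followed):

```python
x = tf.ones((2, 2))
linear_layer = Linear(units=4, input_dim=2)
y = linear_layer(x)  # calling the layer runs its forward pass
print(y)

# The variables assigned in __init__() are tracked automatically.
assert linear_layer.weights == [linear_layer.w, linear_layer.b]
```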
Note you also have access to a quicker shortcut for adding weight to a layer: the add_weight() method:
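Using it, the Linear sketch above can be rewritten more compactly (again a sketch under the same assumptions):

```python
class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super().__init__()
        # add_weight() creates the variable and registers it with the layer
        # in a single step.
        self.w = self.add_weight(
            shape=(input_dim, units), initializer="random_normal", trainable=True
        )
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b
```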
Layers can have non-trainable weights
Besides trainable weights, you can add non-trainable weights to a layer as well. Such weights are meant not to be taken into account during backpropagation, when you are training the layer.
Here's how to add and use a non-trainable weight:
It's part of layer.weights, but it gets categorized as a non-trainable weight:
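A possible sketch (the ComputeSum name is illustrative): a layer that keeps a running sum of its inputs in a non-trainable variable, followed by checks showing how that variable is categorized:

```python
class ComputeSum(keras.layers.Layer):
    def __init__(self, input_dim):
        super().__init__()
        # A non-trainable weight: updated manually in call(), ignored by backprop.
        self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), trainable=False)

    def call(self, inputs):
        self.total.assign_add(tf.reduce_sum(inputs, axis=0))
        return self.total


x = tf.ones((2, 2))
my_sum = ComputeSum(2)
print(my_sum(x).numpy())  # [2. 2.]
print(my_sum(x).numpy())  # [4. 4.]

print("weights:", len(my_sum.weights))                              # 1
print("non-trainable weights:", len(my_sum.non_trainable_weights))  # 1
print("trainable_weights:", my_sum.trainable_weights)               # []
```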
Best practice: deferring weight creation until the shape of the inputs is known
Our Linear layer above took an input_dim argument that was used to compute the shape of the weights w and b in __init__():
In many cases, you may not know in advance the size of your inputs, and you would like to lazily create weights when that value becomes known, some time after instantiating the layer.
In the Keras API, we recommend creating layer weights in the build(self, inputs_shape) method of your layer. Like this:
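A sketch of the Linear layer reworked this way; note that build() receives the shape of the first inputs the layer sees:

```python
class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        # Weight shapes are derived from the input shape, so no `input_dim`
        # constructor argument is needed anymore.
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b
```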
The __call__() method of your layer will automatically run build the first time it is called. You now have a layer that's lazy and thus easier to use:
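For instance (continuing with the build()-based sketch above):

```python
linear_layer = Linear(32)          # no weights yet at instantiation time
y = linear_layer(tf.ones((2, 4)))  # first call builds w with shape (4, 32) and b with shape (32,)
```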
Implementing build() separately as shown above nicely separates creating weights only once from using weights in every call. However, for some advanced custom layers, it can become impractical to separate the state creation and computation. Layer implementers are allowed to defer weight creation to the first __call__(), but need to take care that later calls use the same weights. In addition, since __call__() is likely to be executed for the first time inside a tf.function, any variable creation that takes place in __call__() should be wrapped in a tf.init_scope.
Layers are recursively composable
If you assign a Layer instance as an attribute of another Layer, the outer layer will start tracking the weights created by the inner layer.
We recommend creating such sublayers in the __init__() method and leaving it to the first __call__() to trigger building their weights.
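A sketch of such a composed block, reusing the lazily-built Linear layer from above (the MLPBlock name is illustrative):

```python
class MLPBlock(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        # Sublayers are created in __init__(); their weights are built on first call.
        self.linear_1 = Linear(32)
        self.linear_2 = Linear(32)
        self.linear_3 = Linear(1)

    def call(self, inputs):
        x = self.linear_1(inputs)
        x = tf.nn.relu(x)
        x = self.linear_2(x)
        x = tf.nn.relu(x)
        return self.linear_3(x)


mlp = MLPBlock()
y = mlp(tf.ones(shape=(3, 64)))  # the first call builds every sublayer's weights
print("weights:", len(mlp.weights))                      # 6
print("trainable weights:", len(mlp.trainable_weights))  # 6
```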
The add_loss() method
When writing the call() method of a layer, you can create loss tensors that you will want to use later, when writing your training loop. This is doable by calling self.add_loss(value):
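For example, a sketch of a layer that records an activity regularization loss based on its inputs (the layer name and rate are illustrative):

```python
class ActivityRegularizationLayer(keras.layers.Layer):
    def __init__(self, rate=1e-2):
        super().__init__()
        self.rate = rate

    def call(self, inputs):
        # Record a loss proportional to the magnitude of the inputs.
        self.add_loss(self.rate * tf.reduce_sum(inputs))
        return inputs
```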
These losses (including those created by any inner layer) can be retrieved via layer.losses. This property is reset at the start of every __call__() to the top-level layer, so that layer.losses always contains the loss values created during the last forward pass.
In addition, the losses property also contains regularization losses created for the weights of any inner layer:
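For example, a sketch of an outer layer whose inner Dense layer carries a kernel regularizer; the regularization loss surfaces in the outer layer's losses:

```python
class OuterLayerWithKernelRegularizer(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(
            32, kernel_regularizer=keras.regularizers.l2(1e-3)
        )

    def call(self, inputs):
        return self.dense(inputs)


layer = OuterLayerWithKernelRegularizer()
_ = layer(tf.zeros((1, 1)))
# Contains the kernel regularization loss created by the inner Dense layer
# during the forward pass above.
print(layer.losses)
```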
These losses are meant to be taken into account when writing training loops, like this:
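A schematic sketch of such a loop; `model` and `train_dataset` are placeholders for whatever model and tf.data pipeline you are training:

```python
# Instantiate an optimizer and a per-batch loss function.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Iterate over the batches of a dataset.
for x_batch_train, y_batch_train in train_dataset:
    with tf.GradientTape() as tape:
        logits = model(x_batch_train)
        loss_value = loss_fn(y_batch_train, logits)
        # Add the extra losses created during this forward pass.
        loss_value += sum(model.losses)

    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
```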
For a detailed guide about writing training loops, see the guide to writing a training loop from scratch.
These losses also work seamlessly with fit() (they get automatically summed and added to the main loss, if any):
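For instance, a sketch using the ActivityRegularizationLayer from above inside a small Functional model:

```python
import numpy as np

inputs = keras.Input(shape=(3,))
outputs = ActivityRegularizationLayer()(inputs)
model = keras.Model(inputs, outputs)

# If there is a compiled loss, the add_loss() loss is added to it.
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# The model can also train with only the loss added via add_loss().
model.compile(optimizer="adam")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
```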
The add_metric() method
Similarly to add_loss(), layers also have an add_metric() method for tracking the moving average of a quantity during training.
Consider the following layer: a "logistic endpoint" layer. It takes as inputs predictions & targets, it computes a loss which it tracks via add_loss(), and it computes an accuracy scalar, which it tracks via add_metric().
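A sketch of such an endpoint layer (the LogisticEndpoint name is illustrative):

```python
class LogisticEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_fn = keras.metrics.BinaryAccuracy()

    def call(self, targets, logits, sample_weights=None):
        # Compute the training-time loss value and add it to the layer.
        loss = self.loss_fn(targets, logits, sample_weights)
        self.add_loss(loss)

        # Log accuracy as a metric (its moving average is tracked by the layer).
        acc = self.accuracy_fn(targets, tf.nn.sigmoid(logits), sample_weights)
        self.add_metric(acc, name="accuracy")

        # Return the inference-time prediction tensor.
        return tf.nn.softmax(logits)
```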
Metrics tracked in this way are accessible via layer.metrics:
Just like for add_loss(), these metrics are tracked by fit():
You can optionally enable serialization on your layers
If you need your custom layers to be serializable as part of a Functional model, you can optionally implement a get_config() method:
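For example, a sketch of the Linear layer with a config describing its constructor arguments:

```python
class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        return {"units": self.units}


# The layer can now be recreated from its config:
layer = Linear(64)
config = layer.get_config()
print(config)  # {'units': 64}
new_layer = Linear.from_config(config)
```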
Note that the __init__() method of the base Layer class takes some keyword arguments, in particular a name and a dtype. It's good practice to pass these arguments to the parent class in __init__() and to include them in the layer config:
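A sketch of the same layer updated to forward those arguments and merge them into the config:

```python
class Linear(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)  # forwards `name`, `dtype`, etc. to the base class
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        # Start from the base config (name, dtype, ...) and add our own arguments.
        config = super().get_config()
        config.update({"units": self.units})
        return config
```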
If you need more flexibility when deserializing the layer from its config, you can also override the from_config() class method. This is the base implementation of from_config():
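It simply passes the saved config dict back to the constructor:

```python
@classmethod
def from_config(cls, config):
    return cls(**config)
```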
To learn more about serialization and saving, see the complete guide to saving and serializing models.
Privileged training argument in the call() method
Some layers, in particular the BatchNormalization layer and the Dropout layer, have different behaviors during training and inference. For such layers, it is standard practice to expose a training (boolean) argument in the call() method.
By exposing this argument in call(), you enable the built-in training and evaluation loops (e.g. fit()) to correctly use the layer in training and inference.
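For example, a sketch of a dropout-like layer that uses the training argument to switch behavior (the CustomDropout name is illustrative):

```python
class CustomDropout(keras.layers.Layer):
    def __init__(self, rate, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate

    def call(self, inputs, training=None):
        # Apply dropout noise only during training; pass inputs through at inference.
        if training:
            return tf.nn.dropout(inputs, rate=self.rate)
        return inputs
```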
Privileged mask argument in the call() method
The other privileged argument supported by call() is the mask argument.
You will find it in all Keras RNN layers. A mask is a boolean tensor (one boolean value per timestep in the input) used to skip certain input timesteps when processing timeseries data.
Keras will automatically pass the correct mask argument to __call__() for layers that support it, when a mask is generated by a prior layer. Mask-generating layers are the Embedding layer configured with mask_zero=True, and the Masking layer.
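For example, a minimal sketch of a mask flowing between layers: the Embedding layer generates a mask from the zero padding, and Keras forwards it to the LSTM, which skips the padded timesteps:

```python
inputs = keras.Input(shape=(None,), dtype="int32")
x = keras.layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)(inputs)
# The mask generated by the Embedding layer is propagated to the LSTM automatically.
outputs = keras.layers.LSTM(32)(x)
model = keras.Model(inputs, outputs)
```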
To learn more about masking and how to write masking-enabled layers, please check out the guide "understanding padding and masking".
The Model class
In general, you will use the Layer class to define inner computation blocks, and will use the Model class to define the outer model -- the object you will train.
For instance, in a ResNet50 model, you would have several ResNet blocks subclassing Layer, and a single Model encompassing the entire ResNet50 network.
The Model class has the same API as Layer, with the following differences:
- It exposes built-in training, evaluation, and prediction loops (model.fit(), model.evaluate(), model.predict()).
- It exposes the list of its inner layers, via the model.layers property.
- It exposes saving and serialization APIs (save(), save_weights()...).
Effectively, the Layer class corresponds to what we refer to in the literature as a "layer" (as in "convolution layer" or "recurrent layer") or as a "block" (as in "ResNet block" or "Inception block").
Meanwhile, the Model class corresponds to what is referred to in the literature as a "model" (as in "deep learning model") or as a "network" (as in "deep neural network").
So if you're wondering, "should I use the Layer class or the Model class?", ask yourself: will I need to call fit() on it? Will I need to call save() on it? If so, go with Model. If not (either because your class is just a block in a bigger system, or because you are writing training & saving code yourself), use Layer.
For instance, we could take our mini-resnet example above, and use it to build a Model that we could train with fit(), and that we could save with save_weights():
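A sketch of what that could look like; ResNetBlock stands in for a Layer subclass like the blocks discussed above and is not defined here, and the training/saving calls are shown schematically:

```python
class ResNet(keras.Model):
    def __init__(self, num_classes=1000):
        super().__init__()
        self.block_1 = ResNetBlock()  # hypothetical Layer subclass, not defined here
        self.block_2 = ResNetBlock()
        self.global_pool = keras.layers.GlobalAveragePooling2D()
        self.classifier = keras.layers.Dense(num_classes)

    def call(self, inputs):
        x = self.block_1(inputs)
        x = self.block_2(x)
        x = self.global_pool(x)
        return self.classifier(x)


# resnet = ResNet()
# resnet.fit(dataset, epochs=10)   # built-in training loop
# resnet.save_weights(filepath)    # built-in saving API
```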
Putting it all together: an end-to-end example
Here's what you've learned so far:
- A Layer encapsulates a state (created in __init__() or build()) and some computation (defined in call()).
- Layers can be recursively nested to create new, bigger computation blocks.
- Layers can create and track losses (typically regularization losses) as well as metrics, via add_loss() and add_metric().
- The outer container, the thing you want to train, is a Model. A Model is just like a Layer, but with added training and serialization utilities.
Let's put all of these things together into an end-to-end example: we're going to implement a Variational AutoEncoder (VAE). We'll train it on MNIST digits.
Our VAE will be a subclass of Model, built as a nested composition of layers that subclass Layer. It will feature a regularization loss (KL divergence).
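A sketch of the full model; the layer and argument names (Sampling, Encoder, Decoder, latent_dim, and so on) are illustrative:

```python
from tensorflow.keras import layers


class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


class Encoder(layers.Layer):
    """Maps MNIST digits to a triplet (z_mean, z_log_var, z)."""

    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z


class Decoder(layers.Layer):
    """Converts z, the encoded digit vector, back into a readable digit."""

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        x = self.dense_proj(inputs)
        return self.dense_output(x)


class VariationalAutoEncoder(keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training."""

    def __init__(self, original_dim, intermediate_dim=64, latent_dim=32,
                 name="autoencoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add the KL divergence regularization loss.
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        return reconstructed
```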
Let's write a simple training loop on MNIST:
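A sketch, using the VariationalAutoEncoder defined above and a mean-squared-error reconstruction loss:

```python
original_dim = 784
vae = VariationalAutoEncoder(original_dim, intermediate_dim=64, latent_dim=32)

optimizer = keras.optimizers.Adam(learning_rate=1e-3)
mse_loss_fn = keras.losses.MeanSquaredError()
loss_metric = keras.metrics.Mean()

(x_train, _), _ = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255

train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

epochs = 2
for epoch in range(epochs):
    print("Start of epoch %d" % (epoch,))
    for step, x_batch_train in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            reconstructed = vae(x_batch_train)
            # Reconstruction loss plus the KL term recorded via add_loss().
            loss = mse_loss_fn(x_batch_train, reconstructed)
            loss += sum(vae.losses)
        grads = tape.gradient(loss, vae.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae.trainable_weights))

        loss_metric(loss)
        if step % 100 == 0:
            print("step %d: mean loss = %.4f" % (step, loss_metric.result()))
```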
Note that since the VAE is subclassing Model, it features built-in training loops. So you could also have trained it like this:
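For example (reusing x_train from the loop above):

```python
vae = VariationalAutoEncoder(784, intermediate_dim=64, latent_dim=32)

optimizer = keras.optimizers.Adam(learning_rate=1e-3)
vae.compile(optimizer, loss=keras.losses.MeanSquaredError())
vae.fit(x_train, x_train, epochs=2, batch_size=64)
```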
Beyond object-oriented development: the Functional API
Was this example too much object-oriented development for you? You can also build models using the Functional API. Importantly, choosing one style or another does not prevent you from leveraging components written in the other style: you can always mix-and-match.
For instance, the Functional API example below reuses the same Sampling layer we defined in the example above:
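A sketch of that Functional version, assuming the Sampling layer and the MNIST x_train array from the previous example are still in scope:

```python
original_dim = 784
intermediate_dim = 64
latent_dim = 32

# Define the encoder model.
original_inputs = keras.Input(shape=(original_dim,), name="encoder_input")
x = layers.Dense(intermediate_dim, activation="relu")(original_inputs)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()((z_mean, z_log_var))
encoder = keras.Model(inputs=original_inputs, outputs=z, name="encoder")

# Define the decoder model.
latent_inputs = keras.Input(shape=(latent_dim,), name="z_sampling")
x = layers.Dense(intermediate_dim, activation="relu")(latent_inputs)
outputs = layers.Dense(original_dim, activation="sigmoid")(x)
decoder = keras.Model(inputs=latent_inputs, outputs=outputs, name="decoder")

# Define the VAE model, registering the KL divergence loss with add_loss().
outputs = decoder(z)
vae = keras.Model(inputs=original_inputs, outputs=outputs, name="vae")
kl_loss = -0.5 * tf.reduce_mean(
    z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
)
vae.add_loss(kl_loss)

# Train it.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
vae.compile(optimizer, loss=keras.losses.MeanSquaredError())
vae.fit(x_train, x_train, epochs=3, batch_size=64)
```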
For more information, make sure to read the Functional API guide.