"""
Title: Making new layers and models via subclassing
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/03/01
Last modified: 2023/06/25
Description: Complete guide to writing `Layer` and `Model` objects from scratch.
Accelerator: None
"""

"""
## Introduction

This guide will cover everything you need to know to build your own
subclassed layers and models. In particular, you'll learn about the following features:

- The `Layer` class
- The `add_weight()` method
- Trainable and non-trainable weights
- The `build()` method
- Making sure your layers can be used with any backend
- The `add_loss()` method
- The `training` argument in `call()`
- The `mask` argument in `call()`
- Making sure your layers can be serialized

Let's dive in.
"""

"""
29
## Setup
30
"""
31
32
import numpy as np
33
import keras
34
from keras import ops
35
from keras import layers
36
37
"""
38
## The `Layer` class: the combination of state (weights) and some computation
39
40
One of the central abstractions in Keras is the `Layer` class. A layer
41
encapsulates both a state (the layer's "weights") and a transformation from
42
inputs to outputs (a "call", the layer's forward pass).
43
44
Here's a densely-connected layer. It has two state variables:
45
the variables `w` and `b`.
46
"""
47
48
49
class Linear(keras.layers.Layer):
50
def __init__(self, units=32, input_dim=32):
51
super().__init__()
52
self.w = self.add_weight(
53
shape=(input_dim, units),
54
initializer="random_normal",
55
trainable=True,
56
)
57
self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)
58
59
def call(self, inputs):
60
return ops.matmul(inputs, self.w) + self.b
61
62
63
"""
64
You would use a layer by calling it on some tensor input(s), much like a Python
65
function.
66
"""
67
68
x = ops.ones((2, 2))
69
linear_layer = Linear(4, 2)
70
y = linear_layer(x)
71
print(y)
72
73
"""
74
Note that the weights `w` and `b` are automatically tracked by the layer upon
75
being set as layer attributes:
76
"""
77
78
assert linear_layer.weights == [linear_layer.w, linear_layer.b]
79
80
"""
81
## Layers can have non-trainable weights
82
83
Besides trainable weights, you can add non-trainable weights to a layer as
84
well. Such weights are meant not to be taken into account during
85
backpropagation, when you are training the layer.
86
87
Here's how to add and use a non-trainable weight:
88
"""
89
90
91
class ComputeSum(keras.layers.Layer):
92
def __init__(self, input_dim):
93
super().__init__()
94
self.total = self.add_weight(
95
initializer="zeros", shape=(input_dim,), trainable=False
96
)
97
98
def call(self, inputs):
99
self.total.assign_add(ops.sum(inputs, axis=0))
100
return self.total
101
102
103
x = ops.ones((2, 2))
104
my_sum = ComputeSum(2)
105
y = my_sum(x)
106
print(y.numpy())
107
y = my_sum(x)
108
print(y.numpy())
109
110
"""
111
It's part of `layer.weights`, but it gets categorized as a non-trainable weight:
112
"""
113
114
print("weights:", len(my_sum.weights))
115
print("non-trainable weights:", len(my_sum.non_trainable_weights))
116
117
# It's not included in the trainable weights:
118
print("trainable_weights:", my_sum.trainable_weights)
119
120
"""
121
## Best practice: deferring weight creation until the shape of the inputs is known
122
123
Our `Linear` layer above took an `input_dim` argument that was used to compute
124
the shape of the weights `w` and `b` in `__init__()`:
125
"""
126
127
128
class Linear(keras.layers.Layer):
129
def __init__(self, units=32, input_dim=32):
130
super().__init__()
131
self.w = self.add_weight(
132
shape=(input_dim, units),
133
initializer="random_normal",
134
trainable=True,
135
)
136
self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)
137
138
def call(self, inputs):
139
return ops.matmul(inputs, self.w) + self.b
140
141
142
"""
143
In many cases, you may not know in advance the size of your inputs, and you
144
would like to lazily create weights when that value becomes known, some time
145
after instantiating the layer.
146
147
In the Keras API, we recommend creating layer weights in the
148
`build(self, inputs_shape)` method of your layer. Like this:
149
"""
150
151
152
class Linear(keras.layers.Layer):
153
def __init__(self, units=32):
154
super().__init__()
155
self.units = units
156
157
def build(self, input_shape):
158
self.w = self.add_weight(
159
shape=(input_shape[-1], self.units),
160
initializer="random_normal",
161
trainable=True,
162
)
163
self.b = self.add_weight(
164
shape=(self.units,), initializer="random_normal", trainable=True
165
)
166
167
def call(self, inputs):
168
return ops.matmul(inputs, self.w) + self.b
169
170
171
"""
172
The `__call__()` method of your layer will automatically run build the first time
173
it is called. You now have a layer that's lazy and thus easier to use:
174
"""
175
176
# At instantiation, we don't know on what inputs this is going to get called
177
linear_layer = Linear(32)
178
179
# The layer's weights are created dynamically the first time the layer is called
180
y = linear_layer(x)
181
182
"""
183
Implementing `build()` separately as shown above nicely separates creating weights
184
only once from using weights in every call.
185
"""
186
187
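"""
Because the weight shapes are read off the inputs at build time, the same layer
class adapts to whatever input size it first sees. A minimal sketch (illustrative
only, reusing the `Linear` layer defined above):

```python
layer_a = Linear(32)
_ = layer_a(ops.ones((2, 8)))  # builds `w` with shape (8, 32)

layer_b = Linear(32)
_ = layer_b(ops.ones((2, 128)))  # builds `w` with shape (128, 32)

print(layer_a.w.shape, layer_b.w.shape)
```
"""
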
"""
## Layers are recursively composable

If you assign a Layer instance as an attribute of another Layer, the outer layer
will start tracking the weights created by the inner layer.

We recommend creating such sublayers in the `__init__()` method and leaving it to
the first `__call__()` to trigger building their weights.
"""


class MLPBlock(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.linear_1 = Linear(32)
        self.linear_2 = Linear(32)
        self.linear_3 = Linear(1)

    def call(self, inputs):
        x = self.linear_1(inputs)
        x = keras.activations.relu(x)
        x = self.linear_2(x)
        x = keras.activations.relu(x)
        return self.linear_3(x)


mlp = MLPBlock()
y = mlp(ops.ones(shape=(3, 64)))  # The first call to `mlp` will create the weights
print("weights:", len(mlp.weights))
print("trainable weights:", len(mlp.trainable_weights))

"""
## Backend-agnostic layers and backend-specific layers

As long as a layer only uses APIs from the `keras.ops` namespace
(or other Keras namespaces such as `keras.activations`, `keras.random`, or `keras.layers`),
then it can be used with any backend -- TensorFlow, JAX, or PyTorch.

All layers you've seen so far in this guide work with all Keras backends.

The `keras.ops` namespace gives you access to:

- The NumPy API, e.g. `ops.matmul`, `ops.sum`, `ops.reshape`, `ops.stack`, etc.
- Neural network-specific APIs such as `ops.softmax`, `ops.conv`, `ops.binary_crossentropy`, `ops.relu`, etc.

You can also use backend-native APIs in your layers (such as `tf.nn` functions),
but if you do this, then your layer will only be usable with the backend in question.
For instance, you could write the following JAX-specific layer using `jax.numpy`:

```python
import jax

class Linear(keras.layers.Layer):
    ...

    def call(self, inputs):
        return jax.numpy.matmul(inputs, self.w) + self.b
```

This would be the equivalent TensorFlow-specific layer:

```python
import tensorflow as tf

class Linear(keras.layers.Layer):
    ...

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b
```

And this would be the equivalent PyTorch-specific layer:

```python
import torch

class Linear(keras.layers.Layer):
    ...

    def call(self, inputs):
        return torch.matmul(inputs, self.w) + self.b
```

Because cross-backend compatibility is a tremendously useful property, we strongly
recommend that you always make your layers backend-agnostic by leveraging
only Keras APIs.
"""

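"""
If you do need a backend-specific code path, you can check which backend is active
at runtime via `keras.backend.backend()`. A minimal sketch (illustrative only):

```python
backend = keras.backend.backend()  # e.g. "tensorflow", "jax", or "torch"
print(f"Running on the {backend} backend")
```
"""
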
"""
## The `add_loss()` method

When writing the `call()` method of a layer, you can create loss tensors that
you will want to use later, when writing your training loop. This is doable by
calling `self.add_loss(value)`:
"""


# A layer that creates an activity regularization loss
class ActivityRegularizationLayer(keras.layers.Layer):
    def __init__(self, rate=1e-2):
        super().__init__()
        self.rate = rate

    def call(self, inputs):
        self.add_loss(self.rate * ops.mean(inputs))
        return inputs


"""
These losses (including those created by any inner layer) can be retrieved via
`layer.losses`. This property is reset at the start of every `__call__()` to
the top-level layer, so that `layer.losses` always contains the loss values
created during the last forward pass.
"""


class OuterLayer(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.activity_reg = ActivityRegularizationLayer(1e-2)

    def call(self, inputs):
        return self.activity_reg(inputs)


layer = OuterLayer()
assert len(layer.losses) == 0  # No losses yet since the layer has never been called

_ = layer(ops.zeros((1, 1)))
assert len(layer.losses) == 1  # We created one loss value

# `layer.losses` gets reset at the start of each __call__
_ = layer(ops.zeros((1, 1)))
assert len(layer.losses) == 1  # This is the loss created during the call above

"""
In addition, the `losses` property also contains regularization losses created
for the weights of any inner layer:
"""


class OuterLayerWithKernelRegularizer(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(
            32, kernel_regularizer=keras.regularizers.l2(1e-3)
        )

    def call(self, inputs):
        return self.dense(inputs)


layer = OuterLayerWithKernelRegularizer()
_ = layer(ops.zeros((1, 1)))

# This is `1e-3 * sum(layer.dense.kernel ** 2)`,
# created by the `kernel_regularizer` above.
print(layer.losses)

"""
These losses are meant to be taken into account when writing custom training loops,
as shown in the sketch after the `fit()` example below.

They also work seamlessly with `fit()` (they get automatically summed and added to the main loss, if any):
"""

inputs = keras.Input(shape=(3,))
outputs = ActivityRegularizationLayer()(inputs)
model = keras.Model(inputs, outputs)

# If there is a loss passed in `compile`, the regularization
# losses get added to it
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# It's also possible not to pass any loss in `compile`,
# since the model already has a loss to minimize, via the `add_loss`
# call during the forward pass!
model.compile(optimizer="adam")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

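"""
For reference, here is a minimal sketch of how `model.losses` could be consumed in a
custom training loop. It assumes the TensorFlow backend and a hypothetical
`tf.data.Dataset` named `dataset` that yields `(x_batch, y_batch)` pairs:

```python
import tensorflow as tf

loss_fn = keras.losses.MeanSquaredError()
optimizer = keras.optimizers.Adam(learning_rate=1e-3)

for x_batch, y_batch in dataset:
    with tf.GradientTape() as tape:
        predictions = model(x_batch)
        main_loss = loss_fn(y_batch, predictions)
        # Sum the losses created during the forward pass (e.g. via `add_loss()`)
        # and add them to the main loss before computing gradients.
        total_loss = main_loss + sum(model.losses)
    grads = tape.gradient(total_loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
```
"""
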
"""
## You can optionally enable serialization on your layers

If you need your custom layers to be serializable as part of a
[Functional model](/guides/functional_api/),
you can optionally implement a `get_config()` method:
"""


class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b

    def get_config(self):
        return {"units": self.units}


# Now you can recreate the layer from its config:
layer = Linear(64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)

"""
Note that the `__init__()` method of the base `Layer` class takes some keyword
arguments, in particular a `name` and a `dtype`. It's good practice to pass
these arguments to the parent class in `__init__()` and to include them in the
layer config:
"""


class Linear(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config


layer = Linear(64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)

"""
If you need more flexibility when deserializing the layer from its config, you
can also override the `from_config()` class method. This is the base
implementation of `from_config()`:

```python
@classmethod
def from_config(cls, config):
    return cls(**config)
```

To learn more about serialization and saving, see the complete
[guide to saving and serializing models](/guides/serialization_and_saving/).
"""

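"""
As an illustration only (not part of the original example code), here is a sketch of
what overriding `from_config()` can look like, for a hypothetical layer that serializes
its initializer in `get_config()` and rebuilds the initializer object in `from_config()`:

```python
class LinearWithInitializer(keras.layers.Layer):
    def __init__(self, units=32, initializer="random_normal", **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.initializer = keras.initializers.get(initializer)

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer=self.initializer,
            trainable=True,
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.w)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "units": self.units,
                "initializer": keras.initializers.serialize(self.initializer),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # Turn the serialized initializer back into an initializer object
        # before calling the constructor.
        config["initializer"] = keras.initializers.deserialize(config["initializer"])
        return cls(**config)
```
"""
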
"""
## Privileged `training` argument in the `call()` method

Some layers, in particular the `BatchNormalization` layer and the `Dropout`
layer, have different behaviors during training and inference. For such
layers, it is standard practice to expose a `training` (boolean) argument in
the `call()` method.

By exposing this argument in `call()`, you enable the built-in training and
evaluation loops (e.g. `fit()`) to correctly use the layer in training and
inference.
"""


class CustomDropout(keras.layers.Layer):
    def __init__(self, rate, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs, training=None):
        if training:
            return keras.random.dropout(
                inputs, rate=self.rate, seed=self.seed_generator
            )
        return inputs

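"""
A quick check of the two behaviors (illustrative only; which entries get dropped
depends on the random seed):

```python
dropout = CustomDropout(rate=0.5)
x = ops.ones((2, 4))

print(dropout(x, training=True))   # some entries zeroed, the rest scaled by 1 / (1 - rate)
print(dropout(x, training=False))  # identical to `x`
```
"""
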
"""
## Privileged `mask` argument in the `call()` method

The other privileged argument supported by `call()` is the `mask` argument.

You will find it in all Keras RNN layers. A mask is a boolean tensor (one
boolean value per timestep in the input) used to skip certain input timesteps
when processing timeseries data.

Keras will automatically pass the correct `mask` argument to `__call__()` for
layers that support it, when a mask is generated by a prior layer.
Mask-generating layers are the `Embedding` layer configured with
`mask_zero=True`, and the `Masking` layer.
"""

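"""
As an illustration only (not part of the original guide code), here is a sketch of a
layer that consumes the `mask` argument in `call()`, averaging over the timestep axis
while ignoring masked-out positions (float32 inputs assumed):

```python
class MaskedTemporalMean(keras.layers.Layer):
    def call(self, inputs, mask=None):
        # Without a mask, this is just a plain mean over the timestep axis.
        if mask is None:
            return ops.mean(inputs, axis=1)
        # `mask` has shape (batch, timesteps); broadcast it over the feature axis.
        mask = ops.expand_dims(ops.cast(mask, "float32"), axis=-1)
        total = ops.sum(inputs * mask, axis=1)
        count = ops.maximum(ops.sum(mask, axis=1), 1.0)
        return total / count
```
"""
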
"""
## The `Model` class

In general, you will use the `Layer` class to define inner computation blocks,
and will use the `Model` class to define the outer model -- the object you
will train.

For instance, in a ResNet50 model, you would have several ResNet blocks
subclassing `Layer`, and a single `Model` encompassing the entire ResNet50
network.

The `Model` class has the same API as `Layer`, with the following differences:

- It exposes built-in training, evaluation, and prediction loops
(`model.fit()`, `model.evaluate()`, `model.predict()`).
- It exposes the list of its inner layers, via the `model.layers` property.
- It exposes saving and serialization APIs (`save()`, `save_weights()`...).

Effectively, the `Layer` class corresponds to what we refer to in the
literature as a "layer" (as in "convolution layer" or "recurrent layer") or as
a "block" (as in "ResNet block" or "Inception block").

Meanwhile, the `Model` class corresponds to what is referred to in the
literature as a "model" (as in "deep learning model") or as a "network" (as in
"deep neural network").

So if you're wondering, "should I use the `Layer` class or the `Model` class?",
ask yourself: will I need to call `fit()` on it? Will I need to call `save()`
on it? If so, go with `Model`. If not (either because your class is just a block
in a bigger system, or because you are writing training & saving code yourself),
use `Layer`.

For instance, we could use ResNet blocks like the ones described above to build
a `Model` that we could train with `fit()`, and that we could save with
`save_weights()`:
"""

"""
```python
class ResNet(keras.Model):

    def __init__(self, num_classes=1000):
        super().__init__()
        self.block_1 = ResNetBlock()
        self.block_2 = ResNetBlock()
        self.global_pool = layers.GlobalAveragePooling2D()
        self.classifier = layers.Dense(num_classes)

    def call(self, inputs):
        x = self.block_1(inputs)
        x = self.block_2(x)
        x = self.global_pool(x)
        return self.classifier(x)


resnet = ResNet()
dataset = ...
resnet.fit(dataset, epochs=10)
resnet.save("filepath.keras")
```
"""

"""
## Putting it all together: an end-to-end example

Here's what you've learned so far:

- A `Layer` encapsulates a state (created in `__init__()` or `build()`) and some
computation (defined in `call()`).
- Layers can be recursively nested to create new, bigger computation blocks.
- Layers are backend-agnostic as long as they only use Keras APIs. You can use
backend-native APIs (such as `jax.numpy`, `torch.nn` or `tf.nn`), but then
your layer will only be usable with that specific backend.
- Layers can create and track losses (typically regularization losses)
via `add_loss()`.
- The outer container, the thing you want to train, is a `Model`. A `Model` is
just like a `Layer`, but with added training and serialization utilities.

Let's put all of these things together into an end-to-end example: we're going
to implement a Variational AutoEncoder (VAE) in a backend-agnostic fashion
-- so that it runs the same with TensorFlow, JAX, and PyTorch.
We'll train it on MNIST digits.

Our VAE will be a subclass of `Model`, built as a nested composition of layers
that subclass `Layer`. It will feature a regularization loss (KL divergence).
"""


class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = ops.shape(z_mean)[0]
        dim = ops.shape(z_mean)[1]
        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon


class Encoder(layers.Layer):
    """Maps MNIST digits to a triplet (z_mean, z_log_var, z)."""

    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z


class Decoder(layers.Layer):
    """Converts z, the encoded digit vector, back into a readable digit."""

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        x = self.dense_proj(inputs)
        return self.dense_output(x)


class VariationalAutoEncoder(keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training."""

    def __init__(
        self,
        original_dim,
        intermediate_dim=64,
        latent_dim=32,
        name="autoencoder",
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add KL divergence regularization loss.
        kl_loss = -0.5 * ops.mean(
            z_log_var - ops.square(z_mean) - ops.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        return reconstructed


"""
Let's train it on MNIST using the `fit()` API:
"""

(x_train, _), _ = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255

original_dim = 784
vae = VariationalAutoEncoder(original_dim, 64, 32)

optimizer = keras.optimizers.Adam(learning_rate=1e-3)
vae.compile(optimizer, loss=keras.losses.MeanSquaredError())

vae.fit(x_train, x_train, epochs=2, batch_size=64)
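
"""
Since the VAE is a `Model`, the serialization APIs mentioned earlier apply directly.
Here is a minimal sketch of saving and restoring its weights (the filename
`vae.weights.h5` is just illustrative):

```python
vae.save_weights("vae.weights.h5")

# Recreate the architecture, build it by calling it once, then load the weights.
restored_vae = VariationalAutoEncoder(original_dim, 64, 32)
_ = restored_vae(x_train[:1])
restored_vae.load_weights("vae.weights.h5")
```
"""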