# Source: GitHub repository keras-team/keras-io, path blob/master/guides/functional_api.py
"""
Title: The Functional API
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/03/01
Last modified: 2023/06/25
Description: Complete guide to the functional API.
Accelerator: GPU
"""

"""
## Setup
"""

import numpy as np
import keras
from keras import layers
from keras import ops

"""
## Introduction

The Keras *functional API* is a way to create models that are more flexible
than the `keras.Sequential` API. The functional API can handle models
with non-linear topology, shared layers, and even multiple inputs or outputs.

The main idea is that a deep learning model is usually
a directed acyclic graph (DAG) of layers.
So the functional API is a way to build *graphs of layers*.

Consider the following model:

<div class="k-default-codeblock">
```
(input: 784-dimensional vectors)
       ↧
[Dense (64 units, relu activation)]
       ↧
[Dense (64 units, relu activation)]
       ↧
[Dense (10 units, no activation)]
       ↧
(output: logits of a probability distribution over 10 classes)
```
</div>

This is a basic graph with three layers.
To build this model using the functional API, start by creating an input node:
"""

inputs = keras.Input(shape=(784,))

"""
The shape of the data is set as a 784-dimensional vector.
The batch size is always omitted since only the shape of each sample is specified.

If, for example, you have an image input with a shape of `(32, 32, 3)`,
you would use:
"""

# Just for demonstration purposes.
img_inputs = keras.Input(shape=(32, 32, 3))

"""
The `inputs` object that is returned contains information about the shape and `dtype`
of the input data that you feed to your model.
Here's the shape:
"""

inputs.shape

"""
Here's the dtype:
"""

inputs.dtype

"""
You create a new node in the graph of layers by calling a layer on this `inputs`
object:
"""

dense = layers.Dense(64, activation="relu")
x = dense(inputs)

"""
The "layer call" action is like drawing an arrow from "inputs" to this layer
you created.
You're "passing" the inputs to the `dense` layer, and you get `x` as the output.

Let's add a few more layers to the graph of layers:
"""

x = layers.Dense(64, activation="relu")(x)
outputs = layers.Dense(10)(x)

"""
At this point, you can create a `Model` by specifying its inputs and outputs
in the graph of layers:
"""

model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model")

"""
Let's check out what the model summary looks like:
"""

model.summary()

"""
You can also plot the model as a graph:
"""

keras.utils.plot_model(model, "my_first_model.png")

"""
And, optionally, display the input and output shapes of each layer
in the plotted graph:
"""

keras.utils.plot_model(model, "my_first_model_with_shape_info.png", show_shapes=True)

"""
This figure and the code are almost identical. In the code version,
the connection arrows are replaced by the call operation.

A "graph of layers" is an intuitive mental image for a deep learning model,
and the functional API is a way to create models that closely mirrors this.
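
The parameter counts reported by `model.summary()` can be checked by hand.
The following is a minimal sketch, assuming the three-layer model built above
(784 inputs, then `Dense` layers with 64, 64, and 10 units) and that each `Dense`
layer holds one kernel and one bias. The helper name `dense_params` is introduced
here for illustration only:

```python
def dense_params(input_dim, units):
    # A Dense layer has a kernel of shape (input_dim, units) and a bias of shape (units,).
    return input_dim * units + units


# Input width followed by the number of units of each Dense layer.
layer_sizes = [784, 64, 64, 10]
per_layer = [dense_params(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]
total_params = sum(per_layer)

print(per_layer)  # [50240, 4160, 650]
print(total_params)  # 55050
```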
"""

"""
## Training, evaluation, and inference

Training, evaluation, and inference work in exactly the same way for models
built using the functional API as for `Sequential` models.

The `Model` class offers a built-in training loop (the `fit()` method)
and a built-in evaluation loop (the `evaluate()` method). Note
that you can easily customize these loops to implement your own training routines.
See also the guides on customizing what happens in `fit()`:

- [Writing a custom train step with TensorFlow](/guides/custom_train_step_in_tensorflow/)
- [Writing a custom train step with JAX](/guides/custom_train_step_in_jax/)
- [Writing a custom train step with PyTorch](/guides/custom_train_step_in_torch/)

Here, load the MNIST image data, reshape it into vectors,
fit the model on the data (while monitoring performance on a validation split),
then evaluate the model on the test data:
"""

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train = x_train.reshape(60000, 784).astype("float32") / 255
x_test = x_test.reshape(10000, 784).astype("float32") / 255

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.RMSprop(),
    metrics=["accuracy"],
)

history = model.fit(x_train, y_train, batch_size=64, epochs=2, validation_split=0.2)

test_scores = model.evaluate(x_test, y_test, verbose=2)
print("Test loss:", test_scores[0])
print("Test accuracy:", test_scores[1])

"""
For further reading, see the
[training and evaluation](/guides/training_with_built_in_methods/) guide.
"""

"""
## Save and serialize

Saving the model and serialization work the same way for models built using
the functional API as they do for `Sequential` models. The standard way
to save a functional model is to call `model.save()`
to save the entire model as a single file. You can later recreate the same model
from this file, even if the code that built the model is no longer available.

This saved file includes the:
- model architecture
- model weight values (that were learned during training)
- model training config, if any (as passed to `compile()`)
- optimizer and its state, if any (to restart training where you left off)
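
One way to convince yourself that the architecture and weights survive the round trip
is to compare predictions before and after reloading. The sketch below uses a small
stand-in model rather than the MNIST model from this guide; the names `small_model`
and `restored` and the temporary file path are assumptions made for illustration:

```python
import os
import tempfile

import numpy as np
import keras
from keras import layers

# A tiny stand-in model; the architecture is arbitrary and only keeps the example fast.
inputs = keras.Input(shape=(8,))
hidden = layers.Dense(4, activation="relu")(inputs)
outputs = layers.Dense(3)(hidden)
small_model = keras.Model(inputs, outputs)

x = np.random.random((5, 8)).astype("float32")
before = small_model.predict(x, verbose=0)

# Save to a temporary directory and load the model back from the single file.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "small_model.keras")
    small_model.save(path)
    restored = keras.models.load_model(path)

after = restored.predict(x, verbose=0)
same_predictions = bool(np.allclose(before, after))
print(same_predictions)
```

If the saved file captured both the graph of layers and the weight values,
`same_predictions` should come out as `True`.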
"""

model.save("my_model.keras")
del model
# Recreate the exact same model purely from the file:
model = keras.models.load_model("my_model.keras")

"""
For details, read the model [serialization & saving](/guides/serialization_and_saving/) guide.
"""

"""
## Use the same graph of layers to define multiple models

In the functional API, models are created by specifying their inputs
and outputs in a graph of layers. That means that a single
graph of layers can be used to generate multiple models.

In the example below, you use the same stack of layers to instantiate two models:
an `encoder` model that turns image inputs into 16-dimensional vectors,
and an end-to-end `autoencoder` model for training.
"""

encoder_input = keras.Input(shape=(28, 28, 1), name="img")
x = layers.Conv2D(16, 3, activation="relu")(encoder_input)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.MaxPooling2D(3)(x)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.Conv2D(16, 3, activation="relu")(x)
encoder_output = layers.GlobalMaxPooling2D()(x)

encoder = keras.Model(encoder_input, encoder_output, name="encoder")
encoder.summary()

x = layers.Reshape((4, 4, 1))(encoder_output)
x = layers.Conv2DTranspose(16, 3, activation="relu")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu")(x)
x = layers.UpSampling2D(3)(x)
x = layers.Conv2DTranspose(16, 3, activation="relu")(x)
decoder_output = layers.Conv2DTranspose(1, 3, activation="relu")(x)

autoencoder = keras.Model(encoder_input, decoder_output, name="autoencoder")
autoencoder.summary()

"""
Here, the decoding architecture is strictly symmetrical
to the encoding architecture, so the output shape is the same as
the input shape `(28, 28, 1)`.

The reverse of a `Conv2D` layer is a `Conv2DTranspose` layer,
and the reverse of a `MaxPooling2D` layer is an `UpSampling2D` layer.
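
The symmetry can be verified with plain shape arithmetic. This is a sketch that
assumes the defaults used above (3x3 kernels, stride 1, `"valid"` padding, and a
pooling stride equal to the pool size); the helper functions are hypothetical names
introduced for illustration and track only one spatial dimension:

```python
def conv_valid(size, kernel=3):
    # Conv2D with "valid" padding and stride 1 shrinks a dimension by kernel - 1.
    return size - (kernel - 1)


def conv_transpose_valid(size, kernel=3):
    # Conv2DTranspose with "valid" padding and stride 1 grows it by kernel - 1.
    return size + (kernel - 1)


def max_pool(size, pool=3):
    # MaxPooling2D(pool) uses a stride equal to the pool size by default.
    return (size - pool) // pool + 1


def upsample(size, factor=3):
    # UpSampling2D(factor) repeats each row and column `factor` times.
    return size * factor


# Encoder: Conv2D, Conv2D, MaxPooling2D, Conv2D, Conv2D.
encoder_sizes = [28]
for op in (conv_valid, conv_valid, max_pool, conv_valid, conv_valid):
    encoder_sizes.append(op(encoder_sizes[-1]))

# Decoder: starts from the 16-dimensional code reshaped to 4x4, then mirrors the encoder.
decoder_sizes = [4]
for op in (
    conv_transpose_valid,
    conv_transpose_valid,
    upsample,
    conv_transpose_valid,
    conv_transpose_valid,
):
    decoder_sizes.append(op(decoder_sizes[-1]))

print(encoder_sizes)  # [28, 26, 24, 8, 6, 4]
print(decoder_sizes)  # [4, 6, 8, 24, 26, 28]
```

The decoder sizes are the encoder sizes in reverse, which is why the autoencoder
output ends up at 28x28 again.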
"""

"""
## All models are callable, just like layers

You can treat any model as if it were a layer by invoking it on an `Input` or
on the output of another layer. By calling a model you aren't just reusing
the architecture of the model, you're also reusing its weights.

To see this in action, here's a different take on the autoencoder example that
creates an encoder model, a decoder model, and chains them in two calls
to obtain the autoencoder model:
"""

encoder_input = keras.Input(shape=(28, 28, 1), name="original_img")
x = layers.Conv2D(16, 3, activation="relu")(encoder_input)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.MaxPooling2D(3)(x)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.Conv2D(16, 3, activation="relu")(x)
encoder_output = layers.GlobalMaxPooling2D()(x)

encoder = keras.Model(encoder_input, encoder_output, name="encoder")
encoder.summary()

decoder_input = keras.Input(shape=(16,), name="encoded_img")
x = layers.Reshape((4, 4, 1))(decoder_input)
x = layers.Conv2DTranspose(16, 3, activation="relu")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu")(x)
x = layers.UpSampling2D(3)(x)
x = layers.Conv2DTranspose(16, 3, activation="relu")(x)
decoder_output = layers.Conv2DTranspose(1, 3, activation="relu")(x)

decoder = keras.Model(decoder_input, decoder_output, name="decoder")
decoder.summary()

autoencoder_input = keras.Input(shape=(28, 28, 1), name="img")
encoded_img = encoder(autoencoder_input)
decoded_img = decoder(encoded_img)
autoencoder = keras.Model(autoencoder_input, decoded_img, name="autoencoder")
autoencoder.summary()

"""
As you can see, models can be nested: a model can contain sub-models
(since a model is just like a layer).
A common use case for model nesting is *ensembling*.
For example, here's how to ensemble a set of models into a single model
that averages their predictions:
"""


def get_model():
    inputs = keras.Input(shape=(128,))
    outputs = layers.Dense(1)(inputs)
    return keras.Model(inputs, outputs)


model1 = get_model()
model2 = get_model()
model3 = get_model()

inputs = keras.Input(shape=(128,))
y1 = model1(inputs)
y2 = model2(inputs)
y3 = model3(inputs)
outputs = layers.average([y1, y2, y3])
ensemble_model = keras.Model(inputs=inputs, outputs=outputs)

"""
## Manipulate complex graph topologies

### Models with multiple inputs and outputs

The functional API makes it easy to manipulate multiple inputs and outputs.
This cannot be handled with the `Sequential` API.

For example, if you're building a system for ranking customer issue tickets by
priority and routing them to the correct department,
then the model will have three inputs:

- the title of the ticket (text input),
- the text body of the ticket (text input), and
- any tags added by the user (categorical input)

This model will have two outputs:

- the priority score between 0 and 1 (scalar sigmoid output), and
- the department that should handle the ticket (softmax output
over the set of departments).

You can build this model in a few lines with the functional API:
"""

num_tags = 12  # Number of unique issue tags
num_words = 10000  # Size of vocabulary obtained when preprocessing text data
num_departments = 4  # Number of departments for predictions

title_input = keras.Input(
    shape=(None,), name="title"
)  # Variable-length sequence of ints
body_input = keras.Input(shape=(None,), name="body")  # Variable-length sequence of ints
tags_input = keras.Input(
    shape=(num_tags,), name="tags"
)  # Binary vectors of size `num_tags`

# Embed each word in the title into a 64-dimensional vector
title_features = layers.Embedding(num_words, 64)(title_input)
# Embed each word in the text into a 64-dimensional vector
body_features = layers.Embedding(num_words, 64)(body_input)

# Reduce sequence of embedded words in the title into a single 128-dimensional vector
title_features = layers.LSTM(128)(title_features)
# Reduce sequence of embedded words in the body into a single 32-dimensional vector
body_features = layers.LSTM(32)(body_features)

# Merge all available features into a single large vector via concatenation
x = layers.concatenate([title_features, body_features, tags_input])

# Stick a logistic regression for priority prediction on top of the features
priority_pred = layers.Dense(1, name="priority")(x)
# Stick a department classifier on top of the features
department_pred = layers.Dense(num_departments, name="department")(x)

# Instantiate an end-to-end model predicting both priority and department
model = keras.Model(
    inputs=[title_input, body_input, tags_input],
    outputs={"priority": priority_pred, "department": department_pred},
)

"""
Now plot the model:
"""

keras.utils.plot_model(model, "multi_input_and_output_model.png", show_shapes=True)

"""
When compiling this model, you can assign different losses to each output.
You can even assign different weights to each loss -- to modulate
their contribution to the total training loss.
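
Conceptually, the weighted total is a sum of each output's loss multiplied by its weight.
The sketch below illustrates only that arithmetic: the per-output loss values are made-up
numbers, the weights 1.0 and 0.2 match the ones used in this guide, and other terms
that Keras may add (such as regularization losses) are left out:

```python
# Hypothetical per-output loss values for one batch (made-up numbers).
losses = {"priority": 0.60, "department": 1.50}
# Weights matching the `loss_weights` used in this guide.
loss_weights = {"priority": 1.0, "department": 0.2}

# Each output contributes its loss scaled by its weight.
total_loss = sum(loss_weights[name] * losses[name] for name in losses)
print(round(total_loss, 4))  # 0.9
```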
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss=[
        keras.losses.BinaryCrossentropy(from_logits=True),
        keras.losses.CategoricalCrossentropy(from_logits=True),
    ],
    loss_weights=[1.0, 0.2],
)

"""
Since the output layers have different names, you could also specify
the losses and loss weights with the corresponding layer names:
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss={
        "priority": keras.losses.BinaryCrossentropy(from_logits=True),
        "department": keras.losses.CategoricalCrossentropy(from_logits=True),
    },
    loss_weights={"priority": 1.0, "department": 0.2},
)

"""
Train the model by passing dictionaries of NumPy arrays of inputs and targets:
"""

# Dummy input data
title_data = np.random.randint(num_words, size=(1280, 12))
body_data = np.random.randint(num_words, size=(1280, 100))
tags_data = np.random.randint(2, size=(1280, num_tags)).astype("float32")

# Dummy target data
priority_targets = np.random.random(size=(1280, 1))
dept_targets = np.random.randint(2, size=(1280, num_departments))

model.fit(
    {"title": title_data, "body": body_data, "tags": tags_data},
    {"priority": priority_targets, "department": dept_targets},
    epochs=2,
    batch_size=32,
)

"""
When calling `fit` with a `Dataset` object, it should yield either a
tuple of lists like `([title_data, body_data, tags_data], [priority_targets, dept_targets])`
or a tuple of dictionaries like
`({'title': title_data, 'body': body_data, 'tags': tags_data}, {'priority': priority_targets, 'department': dept_targets})`.

For a more detailed explanation, refer to the
[training and evaluation](/guides/training_with_built_in_methods/) guide.
"""

"""
### A toy ResNet model

In addition to models with multiple inputs and outputs,
the functional API makes it easy to manipulate non-linear connectivity
topologies -- these are models with layers that are not connected sequentially,
which the `Sequential` API cannot handle.

A common use case for this is residual connections.
Let's build a toy ResNet model for CIFAR10 to demonstrate this:
"""

inputs = keras.Input(shape=(32, 32, 3), name="img")
x = layers.Conv2D(32, 3, activation="relu")(inputs)
x = layers.Conv2D(64, 3, activation="relu")(x)
block_1_output = layers.MaxPooling2D(3)(x)

x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_1_output)
x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
block_2_output = layers.add([x, block_1_output])

x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_2_output)
x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
block_3_output = layers.add([x, block_2_output])

x = layers.Conv2D(64, 3, activation="relu")(block_3_output)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10)(x)

model = keras.Model(inputs, outputs, name="toy_resnet")
model.summary()

"""
Plot the model:
"""

keras.utils.plot_model(model, "mini_resnet.png", show_shapes=True)

"""
Now train the model:
"""

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["acc"],
)
# We restrict the data to the first 1000 samples so as to limit execution time
# on Colab. Try to train on the entire dataset until convergence!
model.fit(
    x_train[:1000],
    y_train[:1000],
    batch_size=64,
    epochs=1,
    validation_split=0.2,
)

"""
## Shared layers

Another good use for the functional API is models that use *shared layers*.
Shared layers are layer instances that are reused multiple times in the same model --
they learn features that correspond to multiple paths in the graph-of-layers.

Shared layers are often used to encode inputs from similar spaces
(say, two different pieces of text that feature similar vocabulary).
They enable sharing of information across these different inputs,
and they make it possible to train such a model on less data.
If a given word is seen in one of the inputs,
that will benefit the processing of all inputs that pass through the shared layer.
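
The intuition can be sketched with a plain NumPy lookup table standing in for an
embedding layer. This is an illustration under simplifying assumptions, not how Keras
implements `Embedding`: the table size, the word ids, and the manual row update
(a stand-in for a gradient step) are all made up for the example:

```python
import numpy as np

rng = np.random.default_rng(0)
# One lookup table (1000 words, 8 dimensions to keep it small) shared by both inputs.
shared_table = rng.normal(size=(1000, 8))

text_a = np.array([3, 17, 42])  # word ids from the first input
text_b = np.array([42, 5])  # word ids from the second input

encoded_a = shared_table[text_a]  # shape (3, 8)
encoded_b = shared_table[text_b]  # shape (2, 8)

# Word 42 appears in both inputs and gets the same vector in each.
same_before = bool(np.array_equal(encoded_a[2], encoded_b[0]))

# Simulate an update to word 42's row that was driven by input A only.
shared_table[42] += 0.1
# Input B sees the updated vector too, because there is only one table.
same_after = bool(np.array_equal(shared_table[text_a][2], shared_table[text_b][0]))

print(same_before, same_after)  # True True
```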

To share a layer in the functional API, call the same layer instance multiple times.
For instance, here's an `Embedding` layer shared across two different text inputs:
"""

# Embedding for 1000 unique words mapped to 128-dimensional vectors
shared_embedding = layers.Embedding(1000, 128)

# Variable-length sequence of integers
text_input_a = keras.Input(shape=(None,), dtype="int32")

# Variable-length sequence of integers
text_input_b = keras.Input(shape=(None,), dtype="int32")

# Reuse the same layer to encode both inputs
encoded_input_a = shared_embedding(text_input_a)
encoded_input_b = shared_embedding(text_input_b)

"""
## Extract and reuse nodes in the graph of layers

Because the graph of layers you are manipulating is a static data structure,
it can be accessed and inspected. And this is how you are able to plot
functional models as images.

This also means that you can access the activations of intermediate layers
("nodes" in the graph) and reuse them elsewhere --
which is very useful for something like feature extraction.

Let's look at an example. This is a VGG19 model with weights pretrained on ImageNet:
"""

vgg19 = keras.applications.VGG19()

"""
And these are the intermediate activations of the model,
obtained by querying the graph data structure:
"""

features_list = [layer.output for layer in vgg19.layers]

"""
Use these features to create a new feature-extraction model that returns
the values of the intermediate layer activations:
"""

feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)

img = np.random.random((1, 224, 224, 3)).astype("float32")
extracted_features = feat_extraction_model(img)

"""
This comes in handy for tasks like
[neural style transfer](https://keras.io/examples/generative/neural_style_transfer/),
among other things.
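
The same pattern works on any functional model. As a lighter-weight sketch that avoids
downloading pretrained weights, the example below applies it to a small untrained
stand-in model; the names `tiny_model` and `tiny_extractor` and the layer sizes are
assumptions chosen for illustration:

```python
import numpy as np
import keras
from keras import layers

# A small, untrained stand-in for VGG19 so that nothing has to be downloaded.
inputs = keras.Input(shape=(16, 16, 3))
x = layers.Conv2D(4, 3, activation="relu", name="conv_a")(inputs)
x = layers.Conv2D(8, 3, activation="relu", name="conv_b")(x)
outputs = layers.GlobalAveragePooling2D(name="pool")(x)
tiny_model = keras.Model(inputs, outputs)

# Skip the input layer and collect the symbolic output of every other layer.
tiny_features_list = [layer.output for layer in tiny_model.layers[1:]]
tiny_extractor = keras.Model(inputs=tiny_model.input, outputs=tiny_features_list)

img = np.random.random((1, 16, 16, 3)).astype("float32")
feature_shapes = [tuple(f.shape) for f in tiny_extractor(img)]
print(feature_shapes)  # [(1, 14, 14, 4), (1, 12, 12, 8), (1, 8)]
```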
"""

"""
## Extend the API using custom layers

`keras` includes a wide range of built-in layers, for example:

- Convolutional layers: `Conv1D`, `Conv2D`, `Conv3D`, `Conv2DTranspose`
- Pooling layers: `MaxPooling1D`, `MaxPooling2D`, `MaxPooling3D`, `AveragePooling1D`
- RNN layers: `GRU`, `LSTM`, `ConvLSTM2D`
- `BatchNormalization`, `Dropout`, `Embedding`, etc.

But if you don't find what you need, it's easy to extend the API by creating
your own layers. All layers subclass the `Layer` class and implement:

- A `call` method that specifies the computation done by the layer.
- A `build` method that creates the weights of the layer (this is just a style
convention since you can create weights in `__init__` as well).

To learn more about creating layers from scratch, read the
[custom layers and models](/guides/making_new_layers_and_models_via_subclassing) guide.

The following is a basic implementation of `keras.layers.Dense`:
"""


class CustomDense(layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b


inputs = keras.Input((4,))
outputs = CustomDense(10)(inputs)

model = keras.Model(inputs, outputs)

"""
For serialization support in your custom layer, define a `get_config()`
method that returns the constructor arguments of the layer instance:
"""


class CustomDense(layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b

    def get_config(self):
        return {"units": self.units}


inputs = keras.Input((4,))
outputs = CustomDense(10)(inputs)

model = keras.Model(inputs, outputs)
config = model.get_config()

new_model = keras.Model.from_config(config, custom_objects={"CustomDense": CustomDense})

"""
Optionally, implement the class method `from_config(cls, config)` which is used
when recreating a layer instance given its config dictionary.
The default implementation of `from_config` is:

```python
@classmethod
def from_config(cls, config):
    return cls(**config)
```
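
The round trip between `get_config()` and `from_config()` can be tried on a built-in
layer. This is a minimal sketch using `layers.Dense`; the variable names and the
layer name `example_dense` are arbitrary choices for the example:

```python
from keras import layers

original = layers.Dense(7, activation="relu", name="example_dense")
config = original.get_config()

# from_config is a class method: it builds a new, unbuilt layer from the config dictionary.
clone = layers.Dense.from_config(config)

print(clone.units, clone.name)  # 7 example_dense
```

The clone shares the constructor arguments of the original layer but not its weights,
since the config describes how to build a layer rather than its learned state.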
"""

"""
## When to use the functional API

Should you use the Keras functional API to create a new model,
or just subclass the `Model` class directly? In general, the functional API
is higher-level, easier and safer, and has a number of
features that subclassed models do not support.

However, model subclassing provides greater flexibility when building models
that are not easily expressible as directed acyclic graphs of layers.
For example, you could not implement a Tree-RNN with the functional API
and would have to subclass `Model` directly.

For an in-depth look at the differences between the functional API and
model subclassing, read
[What are Symbolic and Imperative APIs in TensorFlow 2.0?](https://blog.tensorflow.org/2019/01/what-are-symbolic-and-imperative-apis.html).

### Functional API strengths:

The following properties are also true for Sequential models
(which are also data structures), but are not true for subclassed models
(which are Python bytecode, not data structures).

#### Less verbose

There is no `super().__init__(...)`, no `def call(self, ...):`, etc.

Compare:

```python
inputs = keras.Input(shape=(32,))
x = layers.Dense(64, activation='relu')(inputs)
outputs = layers.Dense(10)(x)
mlp = keras.Model(inputs, outputs)
```

With the subclassed version:

```python
class MLP(keras.Model):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_1 = layers.Dense(64, activation='relu')
        self.dense_2 = layers.Dense(10)

    def call(self, inputs):
        x = self.dense_1(inputs)
        return self.dense_2(x)

# Instantiate the model.
mlp = MLP()
# Necessary to create the model's state.
# The model doesn't have a state until it's called at least once.
_ = mlp(ops.zeros((1, 32)))
```

#### Model validation while defining its connectivity graph

In the functional API, the input specification (shape and dtype) is created
in advance (using `Input`). Every time you call a layer,
the layer checks that the specification passed to it matches its assumptions,
and it will raise a helpful error message if not.

This guarantees that any model you can build with the functional API will run.
All debugging -- other than convergence-related debugging --
happens statically during the model construction and not at execution time.
This is similar to type checking in a compiler.
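
As a small illustration of this build-time checking, the sketch below deliberately
connects a layer to an incompatible input. It assumes that `Conv2D` rejects inputs
that are not 4D by raising a `ValueError`; the variable names are chosen for the example:

```python
import keras
from keras import layers

# A 2D convolution expects 4D input (batch, height, width, channels),
# but this Input describes flat 16-dimensional vectors.
flat_inputs = keras.Input(shape=(16,))

error_message = None
try:
    layers.Conv2D(8, 3)(flat_inputs)
except ValueError as err:
    # The mismatch is reported while the graph is being built; no data has been fed yet.
    error_message = str(err)

print(error_message is not None)  # True
```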

#### A functional model is plottable and inspectable

You can plot the model as a graph, and you can easily access intermediate nodes
in this graph. For example, to extract and reuse the activations of intermediate
layers (as seen in a previous example):

```python
features_list = [layer.output for layer in vgg19.layers]
feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)
```

#### A functional model can be serialized or cloned

Because a functional model is a data structure rather than a piece of code,
it is safely serializable and can be saved as a single file
that allows you to recreate the exact same model
without having access to any of the original code.
See the [serialization & saving guide](/guides/serialization_and_saving/).

To serialize a subclassed model, it is necessary for the implementer
to specify a `get_config()`
and `from_config()` method at the model level.


### Functional API weakness:

#### It does not support dynamic architectures

The functional API treats models as DAGs of layers.
This is true for most deep learning architectures, but not all -- for example,
recursive networks or Tree RNNs do not follow this assumption and cannot
be implemented in the functional API.
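
To see why such models do not fit a fixed graph of layers, consider this NumPy sketch of
a recursive encoder. It is a toy illustration rather than a real Tree-RNN: the tree
representation (nested tuples with vector leaves), the helper names, and the single
shared weight matrix are all assumptions made for the example:

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 4
# A single weight matrix shared by every internal node of the tree.
w = rng.normal(size=(2 * dim, dim))


def encode(tree):
    # A leaf is a vector; an internal node is a (left, right) pair of subtrees.
    if isinstance(tree, np.ndarray):
        return tree
    left, right = tree
    children = np.concatenate([encode(left), encode(right)])
    return np.tanh(children @ w)


def leaf():
    return rng.normal(size=dim)


# Two inputs with different structure: the sequence of operations differs per example.
shallow = (leaf(), leaf())
deep = ((leaf(), leaf()), (leaf(), (leaf(), leaf())))

print(encode(shallow).shape, encode(deep).shape)  # (4,) (4,)
```

The number and arrangement of matrix multiplications depends on the shape of each
input tree, so the computation cannot be described once, ahead of time, as a
static graph.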
"""

"""
## Mix-and-match API styles

Choosing between the functional API or Model subclassing isn't a
binary decision that restricts you to one category of models.
All models in the `keras` API can interact with each other, whether they're
`Sequential` models, functional models, or subclassed models that are written
from scratch.

You can always use a functional model or `Sequential` model
as part of a subclassed model or layer:
"""

units = 32
timesteps = 10
input_dim = 5

# Define a Functional model
inputs = keras.Input((None, units))
x = layers.GlobalAveragePooling1D()(inputs)
outputs = layers.Dense(1)(x)
model = keras.Model(inputs, outputs)


class CustomRNN(layers.Layer):
    def __init__(self):
        super().__init__()
        self.units = units
        self.projection_1 = layers.Dense(units=units, activation="tanh")
        self.projection_2 = layers.Dense(units=units, activation="tanh")
        # Our previously-defined Functional model
        self.classifier = model

    def call(self, inputs):
        outputs = []
        state = ops.zeros(shape=(inputs.shape[0], self.units))
        for t in range(inputs.shape[1]):
            x = inputs[:, t, :]
            h = self.projection_1(x)
            y = h + self.projection_2(state)
            state = y
            outputs.append(y)
        features = ops.stack(outputs, axis=1)
        print(features.shape)
        return self.classifier(features)


rnn_model = CustomRNN()
_ = rnn_model(ops.zeros((1, timesteps, input_dim)))

"""
You can use any subclassed layer or model in the functional API
as long as it implements a `call` method that follows one of the following patterns:

- `call(self, inputs, **kwargs)` --
Where `inputs` is a tensor or a nested structure of tensors (e.g. a list of tensors),
and where `**kwargs` are non-tensor arguments (non-inputs).
- `call(self, inputs, training=None, **kwargs)` --
Where `training` is a boolean indicating whether the layer should behave
in training mode or inference mode.
- `call(self, inputs, mask=None, **kwargs)` --
Where `mask` is a boolean mask tensor (useful for RNNs, for instance).
- `call(self, inputs, training=None, mask=None, **kwargs)` --
Of course, you can have both masking and training-specific behavior at the same time.

Additionally, if you implement the `get_config` method on your custom layer or model,
the functional models you create will still be serializable and cloneable.

Here's a quick example of a custom RNN, written from scratch,
being used in a functional model:
"""

units = 32
timesteps = 10
input_dim = 5
batch_size = 16


class CustomRNN(layers.Layer):
    def __init__(self):
        super().__init__()
        self.units = units
        self.projection_1 = layers.Dense(units=units, activation="tanh")
        self.projection_2 = layers.Dense(units=units, activation="tanh")
        self.classifier = layers.Dense(1)

    def call(self, inputs):
        outputs = []
        state = ops.zeros(shape=(inputs.shape[0], self.units))
        for t in range(inputs.shape[1]):
            x = inputs[:, t, :]
            h = self.projection_1(x)
            y = h + self.projection_2(state)
            state = y
            outputs.append(y)
        features = ops.stack(outputs, axis=1)
        return self.classifier(features)


# Note that you specify a static batch size for the inputs with the `batch_shape`
# arg, because the inner computation of `CustomRNN` requires a static batch size
# (when you create the `state` zeros tensor).
inputs = keras.Input(batch_shape=(batch_size, timesteps, input_dim))
x = layers.Conv1D(32, 3)(inputs)
outputs = CustomRNN()(x)

model = keras.Model(inputs, outputs)

rnn_model = CustomRNN()
_ = rnn_model(ops.zeros((1, 10, 5)))