"""
2
Title: Transfer learning & fine-tuning
3
Author: [fchollet](https://twitter.com/fchollet)
4
Date created: 2020/04/15
5
Last modified: 2023/06/25
6
Description: Complete guide to transfer learning & fine-tuning in Keras.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Setup
12
"""
13
14
import numpy as np
15
import keras
16
from keras import layers
17
import tensorflow_datasets as tfds
18
import matplotlib.pyplot as plt
19
20
"""
21
## Introduction
22
23
**Transfer learning** consists of taking features learned on one problem, and
24
leveraging them on a new, similar problem. For instance, features from a model that has
25
learned to identify racoons may be useful to kick-start a model meant to identify
26
tanukis.

Transfer learning is usually done for tasks where your dataset has too little data to
train a full-scale model from scratch.

The most common incarnation of transfer learning in the context of deep learning is the
following workflow:

1. Take layers from a previously trained model.
2. Freeze them, so as to avoid destroying any of the information they contain during
future training rounds.
3. Add some new, trainable layers on top of the frozen layers. They will learn to turn
the old features into predictions on a new dataset.
4. Train the new layers on your dataset.

A last, optional step is **fine-tuning**, which consists of unfreezing the entire
model you obtained above (or part of it), and re-training it on the new data with a
very low learning rate. This can potentially achieve meaningful improvements, by
incrementally adapting the pretrained features to the new data.

First, we will go over the Keras `trainable` API in detail, which underlies most
transfer learning & fine-tuning workflows.

Then, we'll demonstrate the typical workflow by taking a model pretrained on the
ImageNet dataset, and retraining it on the Kaggle "cats vs dogs" classification
dataset.

This is adapted from
[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python)
and the 2016 blog post
["building powerful image classification models using very little data"](https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html).
"""

"""
## Freezing layers: understanding the `trainable` attribute

Layers & models have three weight attributes:

- `weights` is the list of all weights variables of the layer.
- `trainable_weights` is the list of those that are meant to be updated (via gradient
descent) to minimize the loss during training.
- `non_trainable_weights` is the list of those that aren't meant to be trained.
Typically they are updated by the model during the forward pass.

**Example: the `Dense` layer has 2 trainable weights (kernel & bias)**
"""

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

"""
In general, all weights are trainable weights. The only built-in layer that has
non-trainable weights is the `BatchNormalization` layer. It uses non-trainable weights
to keep track of the mean and variance of its inputs during training.
To learn how to use non-trainable weights in your own custom layers, see the
[guide to writing new layers from scratch](/guides/making_new_layers_and_models_via_subclassing/).

**Example: the `BatchNormalization` layer has 2 trainable weights and 2 non-trainable
weights**
"""

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

"""
Layers & models also feature a boolean attribute `trainable`. Its value can be changed.
Setting `layer.trainable` to `False` moves all the layer's weights from trainable to
non-trainable. This is called "freezing" the layer: the state of a frozen layer won't
be updated during training (either when training with `fit()` or when training with
any custom loop that relies on `trainable_weights` to apply gradient updates).

**Example: setting `trainable` to `False`**
"""

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

"""
When a trainable weight becomes non-trainable, its value is no longer updated during
training.
"""

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)

"""
Do not confuse the `layer.trainable` attribute with the argument `training` in
`layer.__call__()` (which controls whether the layer should run its forward pass in
inference mode or training mode). For more information, see the
[Keras FAQ](
https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute).
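
To make the distinction concrete, here is a small illustration (a sketch using a
`Dropout` layer, which is not part of the examples above): the `training` argument
changes what the forward pass computes, while `trainable` only affects whether weights
can be updated by training (and `Dropout` has no weights at all).

```python
dropout = keras.layers.Dropout(0.5)
x = np.ones((1, 4))

# Inference mode: inputs pass through unchanged.
print(dropout(x, training=False))
# Training mode: roughly half the values are zeroed out, the rest are scaled up.
print(dropout(x, training=True))
```
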
"""

"""
## Recursive setting of the `trainable` attribute

If you set `trainable = False` on a model or on any layer that has sublayers,
all children layers become non-trainable as well.

**Example:**
"""

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        inner_model,
        keras.layers.Dense(3, activation="sigmoid"),
    ]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

"""
## The typical transfer-learning workflow

This leads us to how a typical transfer learning workflow can be implemented in Keras:

1. Instantiate a base model and load pre-trained weights into it.
2. Freeze all layers in the base model by setting `trainable = False`.
3. Create a new model on top of the output of one (or several) layers from the base
model.
4. Train your new model on your new dataset.

Note that an alternative, more lightweight workflow could also be:

1. Instantiate a base model and load pre-trained weights into it.
2. Run your new dataset through it and record the output of one (or several) layers
from the base model. This is called **feature extraction**.
3. Use that output as input data for a new, smaller model.

A key advantage of that second workflow is that you only run the base model once on
your data, rather than once per epoch of training. So it's a lot faster & cheaper.
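
Here is a minimal sketch of that feature-extraction variant (assuming hypothetical
`images` and `labels` NumPy arrays, with `images` already resized to 150x150 and scaled
to the `[-1, 1]` range that Xception expects; this is an illustration, not part of the
example developed later in this guide):

```python
base_model = keras.applications.Xception(
    weights="imagenet", input_shape=(150, 150, 3), include_top=False
)
base_model.trainable = False

# Run the base model once over the whole dataset and keep the features.
features = base_model.predict(images)

# Train a small classifier on the precomputed features.
classifier = keras.Sequential(
    [
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(1),
    ]
)
classifier.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
)
classifier.fit(features, labels, epochs=20)
```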

An issue with that second workflow, though, is that it doesn't allow you to dynamically
modify the input data of your new model during training, which is required when doing
data augmentation, for instance. Transfer learning is typically used for tasks where
your new dataset has too little data to train a full-scale model from scratch, and in
such scenarios data augmentation is very important. So in what follows, we will focus
on the first workflow.

Here's what the first workflow looks like in Keras:

First, instantiate a base model with pre-trained weights.

```python
base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.
```

Then, freeze the base model.

```python
base_model.trainable = False
```

Create a new model on top.

```python
inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
```

Train the model on new data.

```python
model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)
```

"""

"""
## Fine-tuning

Once your model has converged on the new data, you can try to unfreeze all or part of
the base model and retrain the whole model end-to-end with a very low learning rate.

This is an optional last step that can potentially give you incremental improvements.
It could also potentially lead to quick overfitting -- keep that in mind.

It is critical to only do this step *after* the model with frozen layers has been
trained to convergence. If you mix randomly-initialized trainable layers with
trainable layers that hold pre-trained features, the randomly-initialized layers will
cause very large gradient updates during training, which will destroy your pre-trained
features.

It's also critical to use a very low learning rate at this stage, because
you are training a much larger model than in the first round of training, on a dataset
that is typically very small.
As a result, you are at risk of overfitting very quickly if you apply large weight
updates. Here, you only want to readapt the pretrained weights in an incremental way.

This is how to implement fine-tuning of the whole base model:

```python
# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are taken into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)
```

**Important note about `compile()` and `trainable`**

Calling `compile()` on a model is meant to "freeze" the behavior of that model. This
implies that the `trainable` attribute values at the time the model is compiled should
be preserved throughout the lifetime of that model, until `compile` is called again.
Hence, if you change any `trainable` value, make sure to call `compile()` again on your
model for your changes to be taken into account.
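
In other words, the pattern to remember is simply (a sketch, with `...` standing in for
the compilation arguments shown above):

```python
base_model.trainable = True  # any change to `trainable`, on any layer or sub-model
model.compile(...)           # recompile so that `fit()` picks up the change
```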

**Important notes about `BatchNormalization` layer**

Many image models contain `BatchNormalization` layers. That layer is a special case on
every imaginable count. Here are a few things to keep in mind.

- `BatchNormalization` contains 2 non-trainable weights that get updated during
training. These are the variables tracking the mean and variance of the inputs.
- When you set `bn_layer.trainable = False`, the `BatchNormalization` layer will
run in inference mode, and will not update its mean & variance statistics. This is not
the case for other layers in general, as
[weight trainability & inference/training modes are two orthogonal concepts](
https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute).
But the two are tied in the case of the `BatchNormalization` layer, as the short check
after this list illustrates.
- When you unfreeze a model that contains `BatchNormalization` layers in order to do
fine-tuning, you should keep the `BatchNormalization` layers in inference mode by
passing `training=False` when calling the base model.
Otherwise the updates applied to the non-trainable weights will suddenly destroy
what the model has learned.
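
Here is a quick check of the first two points (a small sketch, reusing the
`BatchNormalization` layer shown earlier in this guide): freezing the layer moves its
`gamma` and `beta` into the non-trainable weights, and Keras will additionally run the
frozen layer in inference mode.

```python
bn_layer = keras.layers.BatchNormalization()
bn_layer.build((None, 4))

print(len(bn_layer.trainable_weights))      # 2 (gamma & beta)
print(len(bn_layer.non_trainable_weights))  # 2 (moving mean & moving variance)

bn_layer.trainable = False  # Freeze: also makes the layer run in inference mode
print(len(bn_layer.trainable_weights))      # 0
print(len(bn_layer.non_trainable_weights))  # 4
```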

You'll see this pattern in action in the end-to-end example at the end of this guide.
"""

"""
## An end-to-end example: fine-tuning an image classification model on a cats vs. dogs dataset

To solidify these concepts, let's walk you through a concrete end-to-end transfer
learning & fine-tuning example. We will load the Xception model, pre-trained on
ImageNet, and use it on the Kaggle "cats vs. dogs" classification dataset.
"""

"""
### Getting the data

First, let's fetch the cats vs. dogs dataset using TFDS. If you have your own dataset,
you'll probably want to use the utility
`keras.utils.image_dataset_from_directory` to generate similar labeled
dataset objects from a set of images on disk filed into class-specific folders.
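
For example, something along these lines (a sketch; `"path/to/train"` is a hypothetical
directory containing one subfolder per class):

```python
train_ds = keras.utils.image_dataset_from_directory(
    "path/to/train",
    label_mode="binary",
    image_size=(150, 150),
    batch_size=64,
)
```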

Transfer learning is most useful when working with very small datasets. To keep our
dataset small, we will use 40% of the original training data (25,000 images) for
training, 10% for validation, and 10% for testing.
"""

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print(f"Number of training samples: {train_ds.cardinality()}")
print(f"Number of validation samples: {validation_ds.cardinality()}")
print(f"Number of test samples: {test_ds.cardinality()}")

"""
These are the first 9 images in the training dataset -- as you can see, they're all
different sizes.
"""

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

"""
We can also see that label 1 is "dog" and label 0 is "cat".
"""

"""
### Standardizing the data

Our raw images have a variety of sizes. In addition, each pixel consists of 3 integer
values between 0 and 255 (RGB level values). This isn't a great fit for feeding a
neural network. We need to do 2 things:

- Standardize to a fixed image size. We pick 150x150.
- Normalize pixel values between -1 and 1. We'll do this using a `Rescaling` layer as
part of the model itself.

In general, it's a good practice to develop models that take raw data as input, as
opposed to models that take already-preprocessed data. The reason being that, if your
model expects preprocessed data, any time you export your model to use it elsewhere
(in a web browser, in a mobile app), you'll need to reimplement the exact same
preprocessing pipeline. This gets very tricky very quickly. So we should do the least
possible amount of preprocessing before hitting the model.

Here, we'll do image resizing in the data pipeline (because a deep neural network can
only process contiguous batches of data), and we'll do the input value scaling as part
of the model, when we create it.

Let's resize images to 150x150:
"""

resize_fn = keras.layers.Resizing(150, 150)

train_ds = train_ds.map(lambda x, y: (resize_fn(x), y))
validation_ds = validation_ds.map(lambda x, y: (resize_fn(x), y))
test_ds = test_ds.map(lambda x, y: (resize_fn(x), y))

"""
### Using random data augmentation

When you don't have a large image dataset, it's a good practice to artificially
introduce sample diversity by applying random yet realistic transformations to
the training images, such as random horizontal flipping or small random rotations. This
helps expose the model to different aspects of the training data while slowing down
overfitting.
"""

augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
]


def data_augmentation(x):
    for layer in augmentation_layers:
        x = layer(x)
    return x


train_ds = train_ds.map(lambda x, y: (data_augmentation(x), y))

"""
Let's batch the data and use prefetching to optimize loading speed.
"""

from tensorflow import data as tf_data

batch_size = 64

train_ds = train_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
validation_ds = validation_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
test_ds = test_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()

"""
Let's visualize what the first image of the first batch looks like after various random
transformations:
"""

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(np.expand_dims(first_image, 0))
        plt.imshow(np.array(augmented_image[0]).astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")

"""
## Build a model

Now let's build a model that follows the blueprint we've explained earlier.

Note that:

- We add a `Rescaling` layer to scale input values (initially in the `[0, 255]`
range) to the `[-1, 1]` range.
- We add a `Dropout` layer before the classification layer, for regularization.
- We make sure to pass `training=False` when calling the base model, so that it runs in
inference mode and batchnorm statistics don't get updated, even after we unfreeze the
base model for fine-tuning.
"""

base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))

# Pre-trained Xception weights require that the input be scaled
# from (0, 255) to a range of (-1., +1.); the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(inputs)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary(show_trainable=True)

"""
## Train the top layer
"""

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 2
print("Fitting the top layer of the model")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)

"""
## Do a round of fine-tuning of the entire model

Finally, let's unfreeze the base model and train the entire model end-to-end with a low
learning rate.

Importantly, although the base model becomes trainable, it is still running in
inference mode since we passed `training=False` when calling it when we built the
model. This means that the batch normalization layers inside won't update their batch
statistics. If they did, they would wreak havoc on the representations learned by the
model so far.
"""

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary(show_trainable=True)

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 1
print("Fitting the end-to-end model")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)

"""
After this round of fine-tuning, we get a nice improvement here.
Let's evaluate the model on the test dataset:
"""

print("Test dataset evaluation")
model.evaluate(test_ds)