GitHub Repository: keras-team/keras-io
Path: blob/master/guides/training_with_built_in_methods.py
"""
Title: Training & evaluation with the built-in methods
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/03/01
Last modified: 2023/06/25
Description: Complete guide to training & evaluation with `fit()` and `evaluate()`.
Accelerator: GPU
"""

"""
## Setup
"""

# We import torch & TF so as to use torch DataLoaders & tf.data.Datasets.
import torch
import tensorflow as tf

import os
import numpy as np
import keras
from keras import layers
from keras import ops

"""
## Introduction

This guide covers training, evaluation, and prediction (inference) with models
when using built-in APIs for training & validation (such as `Model.fit()`,
`Model.evaluate()` and `Model.predict()`).

If you are interested in leveraging `fit()` while specifying your
own training step function, see the guides on customizing what happens in `fit()`:

- [Writing a custom train step with TensorFlow](/guides/custom_train_step_in_tensorflow/)
- [Writing a custom train step with JAX](/guides/custom_train_step_in_jax/)
- [Writing a custom train step with PyTorch](/guides/custom_train_step_in_torch/)

If you are interested in writing your own training & evaluation loops from
scratch, see the guides on writing training loops:

- [Writing a training loop with TensorFlow](/guides/writing_a_custom_training_loop_in_tensorflow/)
- [Writing a training loop with JAX](/guides/writing_a_custom_training_loop_in_jax/)
- [Writing a training loop with PyTorch](/guides/writing_a_custom_training_loop_in_torch/)

In general, whether you are using built-in loops or writing your own, model training &
evaluation works in exactly the same way across every kind of Keras model --
Sequential models, models built with the Functional API, and models written from
scratch via model subclassing.
"""

"""
## API overview: a first end-to-end example

When passing data to the built-in training loops of a model, you should use one of
the following:

- NumPy arrays (if your data is small and fits in memory)
- Subclasses of `keras.utils.PyDataset`
- `tf.data.Dataset` objects
- PyTorch `DataLoader` instances

In the next few paragraphs, we'll use the MNIST dataset as NumPy arrays, in
order to demonstrate how to use optimizers, losses, and metrics. Afterwards, we'll
take a close look at each of the other options.

Let's consider the following model (here, we build it with the Functional API, but it
could be a Sequential model or a subclassed model as well):
"""

inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, activation="softmax", name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

"""
Here's what the typical end-to-end workflow looks like, consisting of:

- Training
- Validation on a holdout set generated from the original training data
- Evaluation on the test data

We'll use MNIST data for this example.
"""

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Preprocess the data (these are NumPy arrays)
x_train = x_train.reshape(60000, 784).astype("float32") / 255
x_test = x_test.reshape(10000, 784).astype("float32") / 255

y_train = y_train.astype("float32")
y_test = y_test.astype("float32")

# Reserve 10,000 samples for validation
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

"""
We specify the training configuration (optimizer, loss, metrics):
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(),  # Optimizer
    # Loss function to minimize
    loss=keras.losses.SparseCategoricalCrossentropy(),
    # List of metrics to monitor
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

"""
We call `fit()`, which will train the model by slicing the data into "batches" of size
`batch_size`, and repeatedly iterating over the entire dataset for a given number of
`epochs`.
"""

print("Fit model on training data")
history = model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=2,
    # We pass some validation data for
    # monitoring validation loss and metrics
    # at the end of each epoch
    validation_data=(x_val, y_val),
)
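"""
To make the batching arithmetic concrete, here is a minimal NumPy sketch of slicing
50,000 samples into batches of 64. The helper `iter_batches` and the `*_demo` arrays
are hypothetical names used only for illustration; this is not how Keras implements
`fit()` internally, just the same slicing idea.

```python
import math

import numpy as np


def iter_batches(x, y, batch_size):
    # Yield consecutive (x, y) slices; the last slice may be smaller.
    for start in range(0, len(x), batch_size):
        yield x[start : start + batch_size], y[start : start + batch_size]


num_samples, batch_size = 50000, 64
steps_per_epoch = math.ceil(num_samples / batch_size)

# Small stand-in arrays with the same number of samples as the training set.
x_demo = np.zeros((num_samples, 4), dtype="float32")
y_demo = np.zeros((num_samples,), dtype="float32")
batch_sizes = [len(xb) for xb, _ in iter_batches(x_demo, y_demo, batch_size)]

print(steps_per_epoch, batch_sizes[0], batch_sizes[-1])  # 782 64 16
```

One pass over all 782 batches is one epoch; the final batch holds the 16 samples
left over after 781 full batches.
"""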

"""
The returned `history` object holds a record of the loss values and metric values
during training:
"""

print(history.history)

"""
We evaluate the model on the test data via `evaluate()`:
"""

# Evaluate the model on the test data using `evaluate`
print("Evaluate on test data")
results = model.evaluate(x_test, y_test, batch_size=128)
print("test loss, test acc:", results)

# Generate predictions (probabilities -- the output of the last layer)
# on new data using `predict`
print("Generate predictions for 3 samples")
predictions = model.predict(x_test[:3])
print("predictions shape:", predictions.shape)

"""
Now, let's review each piece of this workflow in detail.
"""

"""
## The `compile()` method: specifying a loss, metrics, and an optimizer

To train a model with `fit()`, you need to specify a loss function, an optimizer, and
optionally, some metrics to monitor.

You pass these to the model as arguments to the `compile()` method:
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

"""
The `metrics` argument should be a list -- your model can have any number of metrics.

If your model has multiple outputs, you can specify different losses and metrics for
each output, and you can modulate the contribution of each output to the total loss of
the model. You will find more details about this in the **Passing data to multi-input,
multi-output models** section.

Note that if you're satisfied with the default settings, in many cases the optimizer,
loss, and metrics can be specified via string identifiers as a shortcut:
"""

model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)

"""
For later reuse, let's put our model definition and compile step in functions; we will
call them several times across different examples in this guide.
"""


def get_uncompiled_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
    x = layers.Dense(64, activation="relu", name="dense_2")(x)
    outputs = layers.Dense(10, activation="softmax", name="predictions")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


def get_compiled_model():
    model = get_uncompiled_model()
    model.compile(
        optimizer="rmsprop",
        loss="sparse_categorical_crossentropy",
        metrics=["sparse_categorical_accuracy"],
    )
    return model


"""
### Many built-in optimizers, losses, and metrics are available

In general, you won't have to create your own losses, metrics, or optimizers
from scratch, because what you need is likely to be already part of the Keras API:

Optimizers:

- `SGD()` (with or without momentum)
- `RMSprop()`
- `Adam()`
- etc.

Losses:

- `MeanSquaredError()`
- `KLDivergence()`
- `CosineSimilarity()`
- etc.

Metrics:

- `AUC()`
- `Precision()`
- `Recall()`
- etc.
"""

"""
### Custom losses

If you need to create a custom loss, Keras provides three ways to do so.

The first method involves creating a function that accepts inputs `y_true` and
`y_pred`. The following example shows a loss function that computes the mean squared
error between the real data and the predictions:
"""


def custom_mean_squared_error(y_true, y_pred):
    return ops.mean(ops.square(y_true - y_pred), axis=-1)


model = get_uncompiled_model()
model.compile(optimizer=keras.optimizers.Adam(), loss=custom_mean_squared_error)

# We need to one-hot encode the labels to use MSE
y_train_one_hot = ops.one_hot(y_train, num_classes=10)
model.fit(x_train, y_train_one_hot, batch_size=64, epochs=1)

"""
If you need a loss function that takes in parameters besides `y_true` and `y_pred`, you
can subclass the `keras.losses.Loss` class and implement the following two methods:

- `__init__(self)`: accept parameters to pass during the call of your loss function
- `call(self, y_true, y_pred)`: use the targets (y_true) and the model predictions
  (y_pred) to compute the model's loss

Let's say you want to use mean squared error, but with an added term that
will de-incentivize prediction values far from 0.5 (we assume that the categorical
targets are one-hot encoded and take values between 0 and 1). This
creates an incentive for the model not to be too confident, which may help
reduce overfitting (we won't know if it works until we try!).

Here's how you would do it:
"""


class CustomMSE(keras.losses.Loss):
    def __init__(self, regularization_factor=0.1, name="custom_mse"):
        super().__init__(name=name)
        self.regularization_factor = regularization_factor

    def call(self, y_true, y_pred):
        mse = ops.mean(ops.square(y_true - y_pred), axis=-1)
        reg = ops.mean(ops.square(0.5 - y_pred), axis=-1)
        return mse + reg * self.regularization_factor


model = get_uncompiled_model()
model.compile(optimizer=keras.optimizers.Adam(), loss=CustomMSE())

y_train_one_hot = ops.one_hot(y_train, num_classes=10)
model.fit(x_train, y_train_one_hot, batch_size=64, epochs=1)
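"""
To see what the regularization term does numerically, here is a small NumPy sketch of
the same formula as `CustomMSE.call()`. The function name `custom_mse_numpy` and the
demo arrays are hypothetical and exist only for this worked example.

```python
import numpy as np


def custom_mse_numpy(y_true, y_pred, regularization_factor=0.1):
    # Same arithmetic as CustomMSE.call(), written with NumPy.
    mse = np.mean(np.square(y_true - y_pred), axis=-1)
    reg = np.mean(np.square(0.5 - y_pred), axis=-1)
    return mse + reg * regularization_factor


y_true_demo = np.array([[0.0, 1.0], [0.0, 1.0]])
# First row: a confident, correct prediction. Second row: maximally unsure.
y_pred_demo = np.array([[0.1, 0.9], [0.5, 0.5]])

per_sample = custom_mse_numpy(y_true_demo, y_pred_demo)
print(per_sample)  # approximately [0.026, 0.25]
```

For the confident row, the plain MSE is 0.01 and the penalty adds 0.1 * 0.16 = 0.016;
for the unsure row, the penalty is zero and only the MSE of 0.25 remains.
"""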


"""
### Custom metrics

If you need a metric that isn't part of the API, you can easily create custom metrics
by subclassing the `keras.metrics.Metric` class. You will need to implement 4
methods:

- `__init__(self)`, in which you will create state variables for your metric.
- `update_state(self, y_true, y_pred, sample_weight=None)`, which uses the targets
  y_true and the model predictions y_pred to update the state variables.
- `result(self)`, which uses the state variables to compute the final results.
- `reset_state(self)`, which reinitializes the state of the metric.

State update and results computation are kept separate (in `update_state()` and
`result()`, respectively) because in some cases, the results computation might be very
expensive and would only be done periodically.

Here's a simple example showing how to implement a `CategoricalTruePositives` metric
that counts how many samples were correctly classified as belonging to a given class:
"""


class CategoricalTruePositives(keras.metrics.Metric):
    def __init__(self, name="categorical_true_positives", **kwargs):
        super().__init__(name=name, **kwargs)
        self.true_positives = self.add_variable(
            shape=(), name="ctp", initializer="zeros"
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = ops.reshape(ops.argmax(y_pred, axis=1), (-1, 1))
        values = ops.cast(y_true, "int32") == ops.cast(y_pred, "int32")
        values = ops.cast(values, "float32")
        if sample_weight is not None:
            sample_weight = ops.cast(sample_weight, "float32")
            values = ops.multiply(values, sample_weight)
        self.true_positives.assign_add(ops.sum(values))

    def result(self):
        return self.true_positives.value

    def reset_state(self):
        # The state of the metric will be reset at the start of each epoch.
        self.true_positives.assign(0.0)


model = get_uncompiled_model()
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[CategoricalTruePositives()],
)
model.fit(x_train, y_train, batch_size=64, epochs=3)
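"""
The update/result/reset life cycle can be illustrated without Keras at all. The
following is a plain NumPy sketch; `TruePositivesSketch` is a hypothetical stand-in
class (not a `keras.metrics.Metric`) that accumulates state across batches the same
way a stateful metric does.

```python
import numpy as np


class TruePositivesSketch:
    def __init__(self):
        # State that persists across batches.
        self.true_positives = 0.0

    def update_state(self, y_true, y_pred, sample_weight=None):
        predicted = np.argmax(y_pred, axis=1)
        values = (predicted == y_true.astype("int64")).astype("float32")
        if sample_weight is not None:
            values = values * sample_weight
        self.true_positives += float(np.sum(values))

    def result(self):
        return self.true_positives

    def reset_state(self):
        self.true_positives = 0.0


metric = TruePositivesSketch()
# Batch 1: both samples are classified correctly.
metric.update_state(np.array([0, 1]), np.array([[0.9, 0.1], [0.2, 0.8]]))
# Batch 2: only the second sample is classified correctly.
metric.update_state(np.array([1, 1]), np.array([[0.7, 0.3], [0.4, 0.6]]))
total_after_two_batches = metric.result()
print(total_after_two_batches)  # 3.0

metric.reset_state()  # e.g. at the start of the next epoch
```
"""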

"""
### Handling losses and metrics that don't fit the standard signature

The overwhelming majority of losses and metrics can be computed from `y_true` and
`y_pred`, where `y_pred` is an output of your model -- but not all of them. For
instance, a regularization loss may only require the activation of a layer (there are
no targets in this case), and this activation may not be a model output.

In such cases, you can call `self.add_loss(loss_value)` from inside the `call` method of
a custom layer. Losses added in this way get added to the "main" loss during training
(the one passed to `compile()`). Here's a simple example that adds activity
regularization (note that activity regularization is built into all Keras layers --
this layer is just for the sake of providing a concrete example):
"""


class ActivityRegularizationLayer(layers.Layer):
    def call(self, inputs):
        self.add_loss(ops.sum(inputs) * 0.1)
        return inputs  # Pass-through layer.


inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)

# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)

x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# The displayed loss will be much higher than before
# due to the regularization component.
model.fit(x_train, y_train, batch_size=64, epochs=1)
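"""
As a rough numeric illustration of why the displayed loss grows, the sketch below adds
an activity penalty (sum of activations times 0.1, as in the layer above) to a
main loss value. The activation values and `main_loss` are made-up numbers chosen
for this example.

```python
import numpy as np

# Hypothetical activations of the regularized layer for a batch of 2 samples.
activations = np.array([[0.5, 1.5], [2.0, 0.0]])
activity_penalty = np.sum(activations) * 0.1

main_loss = 0.3  # hypothetical value of the loss passed to `compile()`
total_loss = main_loss + activity_penalty
print(activity_penalty, total_loss)
```

Because the penalty sums over every unit and every sample in the batch, it can easily
dominate the main loss when a layer has many active units.
"""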

"""
Note that when you pass losses via `add_loss()`, it becomes possible to call
`compile()` without a loss function, since the model already has a loss to minimize.

Consider the following `LogisticEndpoint` layer: it takes as inputs
targets & logits, and it tracks a crossentropy loss via `add_loss()`.
"""


class LogisticEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)

    def call(self, targets, logits, sample_weights=None):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        loss = self.loss_fn(targets, logits, sample_weights)
        self.add_loss(loss)

        # Return the inference-time prediction tensor (for `.predict()`).
        return ops.softmax(logits)


"""
You can use it in a model with two inputs (input data & targets), compiled without a
`loss` argument, like this:
"""

inputs = keras.Input(shape=(3,), name="inputs")
targets = keras.Input(shape=(10,), name="targets")
logits = keras.layers.Dense(10)(inputs)
predictions = LogisticEndpoint(name="predictions")(targets, logits)

model = keras.Model(inputs=[inputs, targets], outputs=predictions)
model.compile(optimizer="adam")  # No loss argument!

data = {
    "inputs": np.random.random((3, 3)),
    "targets": np.random.random((3, 10)),
}
model.fit(data)

"""
For more information about training multi-input models, see the section **Passing data
to multi-input, multi-output models**.
"""

"""
### Automatically setting apart a validation holdout set

In the first end-to-end example you saw, we used the `validation_data` argument to pass
a tuple of NumPy arrays `(x_val, y_val)` to the model for evaluating a validation loss
and validation metrics at the end of each epoch.

Here's another option: the argument `validation_split` allows you to automatically
reserve part of your training data for validation. The argument value represents the
fraction of the data to be reserved for validation, so it should be set to a number
higher than 0 and lower than 1. For instance, `validation_split=0.2` means "use 20% of
the data for validation", and `validation_split=0.6` means "use 60% of the data for
validation".

The validation set is formed by taking the last x% of the samples of the arrays
received by the `fit()` call, before any shuffling.

Note that you can only use `validation_split` when training with NumPy data.
"""

model = get_compiled_model()
model.fit(x_train, y_train, batch_size=64, validation_split=0.2, epochs=1)
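"""
The "last x% of the samples, before any shuffling" rule can be sketched in NumPy as
follows. The helper `split_for_validation` is a hypothetical name, and the exact
rounding used by Keras is an assumption here; the point is only that the holdout
set is the tail of the arrays.

```python
import math

import numpy as np


def split_for_validation(x, y, validation_split):
    # Index of the first validation sample; everything before it is training data.
    split_at = int(math.floor(len(x) * (1.0 - validation_split)))
    return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])


x_demo = np.arange(8)
y_demo = np.arange(8) * 10
(train_x, train_y), (val_x, val_y) = split_for_validation(x_demo, y_demo, 0.25)

print(train_x)  # [0 1 2 3 4 5]
print(val_x)  # [6 7]
```

Since the tail is taken before shuffling, data sorted by class or by time should be
shuffled beforehand if you want a representative validation set.
"""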

"""
## Training & evaluation using `tf.data` Datasets

In the past few paragraphs, you've seen how to handle losses, metrics, and optimizers,
and you've seen how to use the `validation_data` and `validation_split` arguments in
`fit()`, when your data is passed as NumPy arrays.

Another option is to use an iterator-like object, such as a `tf.data.Dataset`, a
PyTorch `DataLoader`, or a Keras `PyDataset`. Let's take a look at the first of these.

The `tf.data` API is a set of utilities in TensorFlow 2.0 for loading and preprocessing
data in a way that's fast and scalable. For a complete guide about creating `Datasets`,
see the [tf.data documentation](https://www.tensorflow.org/guide/data).

**You can use `tf.data` to train your Keras
models regardless of the backend you're using --
whether it's JAX, PyTorch, or TensorFlow.**

You can pass a `Dataset` instance directly to the methods `fit()`, `evaluate()`, and
`predict()`:
"""

model = get_compiled_model()

# First, let's create a training Dataset instance.
# For the sake of our example, we'll use the same MNIST data as before.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# Now we get a test dataset.
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(64)

# Since the dataset already takes care of batching,
# we don't pass a `batch_size` argument.
model.fit(train_dataset, epochs=3)

# You can also evaluate or predict on a dataset.
print("Evaluate")
result = model.evaluate(test_dataset)
dict(zip(model.metrics_names, result))

"""
Note that the Dataset is reset at the end of each epoch, so it can be reused for the
next epoch.

If you want to run training only on a specific number of batches from this Dataset, you
can pass the `steps_per_epoch` argument, which specifies how many training steps the
model should run using this Dataset before moving on to the next epoch.
"""

model = get_compiled_model()

# Prepare the training dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# Only use 100 batches per epoch (that's 64 * 100 samples)
model.fit(train_dataset, epochs=3, steps_per_epoch=100)

"""
You can also pass a `Dataset` instance as the `validation_data` argument in `fit()`:
"""

model = get_compiled_model()

# Prepare the training dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# Prepare the validation dataset
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(64)

model.fit(train_dataset, epochs=1, validation_data=val_dataset)

"""
At the end of each epoch, the model will iterate over the validation dataset and
compute the validation loss and validation metrics.

If you want to run validation only on a specific number of batches from this dataset,
you can pass the `validation_steps` argument, which specifies how many validation
steps the model should run with the validation dataset before interrupting validation
and moving on to the next epoch:
"""

model = get_compiled_model()

# Prepare the training dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# Prepare the validation dataset
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(64)

model.fit(
    train_dataset,
    epochs=1,
    # Only run validation using the first 10 batches of the dataset
    # using the `validation_steps` argument
    validation_data=val_dataset,
    validation_steps=10,
)

"""
Note that the validation dataset will be reset after each use (so that you will always
be evaluating on the same samples from epoch to epoch).

The argument `validation_split` (generating a holdout set from the training data) is
not supported when training from `Dataset` objects, since this feature requires the
ability to index the samples of the datasets, which is not possible in general with
the `Dataset` API.
"""

"""
## Training & evaluation using `PyDataset` instances

`keras.utils.PyDataset` is a utility that you can subclass to obtain
a Python generator with two important properties:

- It works well with multiprocessing.
- It can be shuffled (e.g. when passing `shuffle=True` in `fit()`).

A `PyDataset` must implement two methods:

- `__getitem__`
- `__len__`

The method `__getitem__` should return a complete batch.
If you want to modify your dataset between epochs, you may implement `on_epoch_end`.

Here's a quick example:
"""


class ExamplePyDataset(keras.utils.PyDataset):
    def __init__(self, x, y, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.x = x
        self.y = y
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size : (idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size : (idx + 1) * self.batch_size]
        return batch_x, batch_y


train_py_dataset = ExamplePyDataset(x_train, y_train, batch_size=32)
val_py_dataset = ExamplePyDataset(x_val, y_val, batch_size=32)
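"""
The batch-centric contract (`__len__` counts batches, `__getitem__` returns a whole
batch) can be checked on a tiny example without Keras. `BatchedArraysSketch` below is
a hypothetical plain-Python class that mirrors the indexing logic of
`ExamplePyDataset`; it does not subclass `keras.utils.PyDataset`.

```python
import math

import numpy as np


class BatchedArraysSketch:
    def __init__(self, x, y, batch_size):
        self.x = x
        self.y = y
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        low = idx * self.batch_size
        high = low + self.batch_size
        return self.x[low:high], self.y[low:high]


demo = BatchedArraysSketch(np.arange(10), np.arange(10), batch_size=4)
batch_lengths = [len(demo[i][0]) for i in range(len(demo))]
print(len(demo), batch_lengths)  # 3 [4, 4, 2]
```

Ten samples with a batch size of 4 give three batches, the last of which holds the
two leftover samples.
"""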

"""
To fit the model, pass the dataset as the `x` argument (no need for a `y`
argument since the dataset includes the targets), and pass the validation dataset
as the `validation_data` argument. There is no need for the `batch_size` argument
either, since the dataset is already batched!
"""

model = get_compiled_model()
model.fit(train_py_dataset, validation_data=val_py_dataset, epochs=1)

"""
Evaluating the model is just as easy:
"""

model.evaluate(val_py_dataset)

"""
Importantly, `PyDataset` objects support three common constructor arguments
that handle the parallel processing configuration:

- `workers`: Number of workers to use in multithreading or
    multiprocessing. Typically, you'd set it to the number of
    cores on your CPU.
- `use_multiprocessing`: Whether to use Python multiprocessing for
    parallelism. Setting this to `True` means that your
    dataset will be replicated in multiple forked processes.
    This is necessary to gain compute-level (rather than I/O level)
    benefits from parallelism. However it can only be set to
    `True` if your dataset can be safely pickled.
- `max_queue_size`: Maximum number of batches to keep in the queue
    when iterating over the dataset in a multithreaded or
    multiprocessed setting.
    You can reduce this value to reduce the CPU memory consumption of
    your dataset. It defaults to 10.

By default, multiprocessing is disabled (`use_multiprocessing=False`) and only
one thread is used. You should make sure to only turn on `use_multiprocessing` if
your code is running inside a Python `if __name__ == "__main__":` block in order
to avoid issues.

Here's a 4-thread, non-multiprocessed example:
"""

train_py_dataset = ExamplePyDataset(x_train, y_train, batch_size=32, workers=4)
val_py_dataset = ExamplePyDataset(x_val, y_val, batch_size=32, workers=4)

model = get_compiled_model()
model.fit(train_py_dataset, validation_data=val_py_dataset, epochs=1)

"""
## Training & evaluation using PyTorch `DataLoader` objects

All built-in training and evaluation APIs are also compatible with `torch.utils.data.Dataset` and
`torch.utils.data.DataLoader` objects -- regardless of whether you're using the PyTorch backend,
or the JAX or TensorFlow backends. Let's take a look at a simple example.

Unlike `PyDataset` objects, which are batch-centric, PyTorch `Dataset` objects are
sample-centric: the `__len__` method returns the number of samples,
and the `__getitem__` method returns a specific sample.
"""


class ExampleTorchDataset(torch.utils.data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


train_torch_dataset = ExampleTorchDataset(x_train, y_train)
val_torch_dataset = ExampleTorchDataset(x_val, y_val)

"""
To use a PyTorch Dataset, you need to wrap it into a `DataLoader`, which takes care
of batching and shuffling:
"""

train_dataloader = torch.utils.data.DataLoader(
    train_torch_dataset, batch_size=32, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_torch_dataset, batch_size=32, shuffle=True
)
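"""
Conceptually, going from a sample-centric dataset to batches means picking indices
(optionally shuffled), fetching one sample per index, and stacking the results. The
NumPy sketch below illustrates that idea only; `batches_from_samples` is a
hypothetical helper and not how `torch.utils.data.DataLoader` is implemented.

```python
import numpy as np


def batches_from_samples(dataset, batch_size, shuffle, seed=0):
    indices = np.arange(len(dataset))
    if shuffle:
        np.random.default_rng(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # Fetch individual samples, then stack them into batch arrays.
        samples = [dataset[int(i)] for i in indices[start : start + batch_size]]
        xs, ys = zip(*samples)
        yield np.stack(xs), np.stack(ys)


# Five (features, label) samples; any object with __len__ and __getitem__ works.
samples = [(np.full((3,), i, dtype="float32"), np.float32(i)) for i in range(5)]
batch_shapes = [
    (xb.shape, yb.shape)
    for xb, yb in batches_from_samples(samples, batch_size=2, shuffle=True)
]
print(batch_shapes)  # [((2, 3), (2,)), ((2, 3), (2,)), ((1, 3), (1,))]
```
"""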

"""
Now you can use them in the Keras API just like any other iterator. As with the other
pre-batched inputs, no `batch_size` argument is passed to `fit()`:
"""

model = get_compiled_model()
model.fit(train_dataloader, validation_data=val_dataloader, epochs=1)
model.evaluate(val_dataloader)

"""
## Using sample weighting and class weighting

With the default settings the weight of a sample is decided by its frequency
in the dataset. There are two methods to weight the data, independent of
sample frequency:

* Class weights
* Sample weights
"""

"""
### Class weights

This is set by passing a dictionary to the `class_weight` argument to
`Model.fit()`. This dictionary maps class indices to the weight that should
be used for samples belonging to this class.

This can be used to balance classes without resampling, or to train a
model that gives more importance to a particular class.

For instance, if class "0" is half as represented as class "1" in your data,
you could use `Model.fit(..., class_weight={0: 1., 1: 0.5})`.
"""

"""
Here's a NumPy example where we use class weights or sample weights to
give more importance to the correct classification of class #5 (which
is the digit "5" in the MNIST dataset).
"""

class_weight = {
    0: 1.0,
    1: 1.0,
    2: 1.0,
    3: 1.0,
    4: 1.0,
    # Set weight "2" for class "5",
    # making this class 2x more important
    5: 2.0,
    6: 1.0,
    7: 1.0,
    8: 1.0,
    9: 1.0,
}

print("Fit with class weight")
model = get_compiled_model()
model.fit(x_train, y_train, class_weight=class_weight, batch_size=64, epochs=1)
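"""
A class-weight dictionary can be thought of as a lookup table that assigns each
sample the weight of its class. The sketch below shows that mapping in NumPy; the
names `class_weight_demo` and `labels` are hypothetical, and whether Keras performs
the lookup exactly this way internally is an assumption.

```python
import numpy as np

class_weight_demo = {0: 1.0, 1: 1.0, 5: 2.0}
labels = np.array([0, 5, 1, 5])

# Look up the weight of each sample's class.
per_sample_weight = np.array([class_weight_demo[int(c)] for c in labels])
print(per_sample_weight)  # [1. 2. 1. 2.]
```

The resulting array has one entry per sample, which is the same form as the
`sample_weight` argument of `Model.fit()`.
"""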

"""
### Sample weights

For fine-grained control, or if you are not building a classifier,
you can use "sample weights".

- When training from NumPy data: Pass the `sample_weight`
  argument to `Model.fit()`.
- When training from `tf.data` or any other sort of iterator:
  Yield `(input_batch, label_batch, sample_weight_batch)` tuples.

A "sample weights" array is an array of numbers that specify how much weight
each sample in a batch should have in computing the total loss. It is commonly
used in imbalanced classification problems (the idea being to give more weight
to rarely-seen classes).

When the weights used are ones and zeros, the array can be used as a *mask* for
the loss function (entirely discarding the contribution of certain samples to
the total loss).
"""

sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train == 5] = 2.0

print("Fit with sample weight")
model = get_compiled_model()
model.fit(x_train, y_train, sample_weight=sample_weight, batch_size=64, epochs=1)
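"""
The effect of weights and masks on per-sample loss values can be shown with plain
arithmetic. The loss values below are made-up numbers, and the final reduction
(how the weighted values are averaged into one scalar) is deliberately left out,
since it depends on the loss's reduction setting.

```python
import numpy as np

per_sample_loss = np.array([0.2, 1.0, 0.5, 0.3])  # hypothetical loss values

# Doubling the weight of the second sample doubles its contribution.
weights = np.array([1.0, 2.0, 1.0, 1.0])
weighted = per_sample_loss * weights
print(weighted)  # [0.2 2.  0.5 0.3]

# A 0/1 weight array acts as a mask: the second sample is discarded entirely.
mask = np.array([1.0, 0.0, 1.0, 1.0])
masked = per_sample_loss * mask
print(masked)  # [0.2 0.  0.5 0.3]
```
"""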

"""
Here's a matching `Dataset` example:
"""

sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train == 5] = 2.0

# Create a Dataset that includes sample weights
# (3rd element in the return tuple).
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train, sample_weight))

# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

model = get_compiled_model()
model.fit(train_dataset, epochs=1)

"""
## Passing data to multi-input, multi-output models

In the previous examples, we were considering a model with a single input (a tensor of
shape `(784,)`) and a single output (a prediction tensor of shape `(10,)`). But what
about models that have multiple inputs or outputs?

Consider the following model, which has an image input of shape `(32, 32, 3)` (that's
`(height, width, channels)`) and a time series input of shape `(None, 10)` (that's
`(timesteps, features)`). Our model will have two outputs computed from the
combination of these inputs: a "score" (of shape `(1,)`) and a probability
distribution over five classes (of shape `(5,)`).
"""

image_input = keras.Input(shape=(32, 32, 3), name="img_input")
timeseries_input = keras.Input(shape=(None, 10), name="ts_input")

x1 = layers.Conv2D(3, 3)(image_input)
x1 = layers.GlobalMaxPooling2D()(x1)

x2 = layers.Conv1D(3, 3)(timeseries_input)
x2 = layers.GlobalMaxPooling1D()(x2)

x = layers.concatenate([x1, x2])

score_output = layers.Dense(1, name="score_output")(x)
class_output = layers.Dense(5, name="class_output")(x)

model = keras.Model(
    inputs=[image_input, timeseries_input], outputs=[score_output, class_output]
)

"""
Let's plot this model, so you can clearly see what we're doing here (note that the
shapes shown in the plot are batch shapes, rather than per-sample shapes).
"""

keras.utils.plot_model(model, "multi_input_and_output_model.png", show_shapes=True)

"""
At compilation time, we can specify different losses for different outputs, by passing
the loss functions as a list:
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss=[
        keras.losses.MeanSquaredError(),
        keras.losses.CategoricalCrossentropy(),
    ],
)

"""
If we only passed a single loss function to the model, the same loss function would be
applied to every output (which is not appropriate here).

Likewise for metrics:
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss=[
        keras.losses.MeanSquaredError(),
        keras.losses.CategoricalCrossentropy(),
    ],
    metrics=[
        [
            keras.metrics.MeanAbsolutePercentageError(),
            keras.metrics.MeanAbsoluteError(),
        ],
        [keras.metrics.CategoricalAccuracy()],
    ],
)

"""
Since we gave names to our output layers, we could also specify per-output losses and
metrics via a dict:
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss={
        "score_output": keras.losses.MeanSquaredError(),
        "class_output": keras.losses.CategoricalCrossentropy(),
    },
    metrics={
        "score_output": [
            keras.metrics.MeanAbsolutePercentageError(),
            keras.metrics.MeanAbsoluteError(),
        ],
        "class_output": [keras.metrics.CategoricalAccuracy()],
    },
)

"""
We recommend the use of explicit names and dicts if you have more than 2 outputs.

It's possible to give different weights to different output-specific losses (for
instance, one might wish to privilege the "score" loss in our example, by giving it 2x
the importance of the class loss), using the `loss_weights` argument:
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss={
        "score_output": keras.losses.MeanSquaredError(),
        "class_output": keras.losses.CategoricalCrossentropy(),
    },
    metrics={
        "score_output": [
            keras.metrics.MeanAbsolutePercentageError(),
            keras.metrics.MeanAbsoluteError(),
        ],
        "class_output": [keras.metrics.CategoricalAccuracy()],
    },
    loss_weights={"score_output": 2.0, "class_output": 1.0},
)
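
"""
With `loss_weights`, the quantity being minimized is the weighted sum of the per-output
losses. The NumPy sketch below works through that arithmetic by hand. The arrays are
made-up values, and two classes are used instead of five to keep the example short; it
illustrates the formula rather than reproducing what Keras does internally.

```python
import numpy as np

# Hypothetical targets and predictions for a batch of two samples.
score_true = np.array([[1.0], [0.0]])
score_pred = np.array([[0.5], [0.5]])
class_true = np.array([[1.0, 0.0], [0.0, 1.0]])
class_pred = np.array([[0.8, 0.2], [0.4, 0.6]])

# Mean squared error for the "score" output.
score_loss = np.mean((score_true - score_pred) ** 2)

# Categorical crossentropy for the "class" output, averaged over the batch.
class_loss = np.mean(-np.sum(class_true * np.log(class_pred), axis=-1))

# Weighted sum, mirroring loss_weights={"score_output": 2.0, "class_output": 1.0}.
total_loss = 2.0 * score_loss + 1.0 * class_loss

print(score_loss, class_loss, total_loss)
```

Doubling the weight of the score loss doubles its contribution to the gradients, so the
optimizer trades off class accuracy for score accuracy more readily.
"""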

"""
You could also choose not to compute a loss for certain outputs, if these outputs are
meant for prediction but not for training:
"""

# List loss version
model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss=[None, keras.losses.CategoricalCrossentropy()],
)

# Or dict loss version
model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss={"class_output": keras.losses.CategoricalCrossentropy()},
)

"""
Passing data to a multi-input or multi-output model in `fit()` works much like
specifying a loss function in `compile()`: you can pass **lists of NumPy arrays** (with
1:1 mapping to the outputs that received a loss function) or **dicts mapping output
names to NumPy arrays**.
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss=[
        keras.losses.MeanSquaredError(),
        keras.losses.CategoricalCrossentropy(),
    ],
)

# Generate dummy NumPy data
img_data = np.random.random_sample(size=(100, 32, 32, 3))
ts_data = np.random.random_sample(size=(100, 20, 10))
score_targets = np.random.random_sample(size=(100, 1))
class_targets = np.random.random_sample(size=(100, 5))

# Fit on lists
model.fit([img_data, ts_data], [score_targets, class_targets], batch_size=32, epochs=1)

# Alternatively, fit on dicts
model.fit(
    {"img_input": img_data, "ts_input": ts_data},
    {"score_output": score_targets, "class_output": class_targets},
    batch_size=32,
    epochs=1,
)

"""
Here's the `Dataset` use case: as with NumPy arrays, the `Dataset` should return a
tuple of dicts.
"""

train_dataset = tf.data.Dataset.from_tensor_slices(
    (
        {"img_input": img_data, "ts_input": ts_data},
        {"score_output": score_targets, "class_output": class_targets},
    )
)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

model.fit(train_dataset, epochs=1)

"""
## Using callbacks

Callbacks in Keras are objects that are called at different points during training (at
the start of an epoch, at the end of a batch, at the end of an epoch, etc.). They
can be used to implement certain behaviors, such as:

- Doing validation at different points during training (beyond the built-in per-epoch
validation)
- Checkpointing the model at regular intervals or when it exceeds a certain accuracy
threshold
- Changing the learning rate of the model when training seems to be plateauing
- Doing fine-tuning of the top layers when training seems to be plateauing
- Sending email or instant message notifications when training ends or when a certain
performance threshold is exceeded
- Etc.

Callbacks can be passed as a list to your call to `fit()`:
"""

model = get_compiled_model()

callbacks = [
    keras.callbacks.EarlyStopping(
        # Stop training when `val_loss` is no longer improving
        monitor="val_loss",
        # "no longer improving" being defined as "no better than 1e-2 less"
        min_delta=1e-2,
        # "no longer improving" being further defined as "for at least 2 epochs"
        patience=2,
        verbose=1,
    )
]
model.fit(
    x_train,
    y_train,
    epochs=20,
    batch_size=64,
    callbacks=callbacks,
    validation_split=0.2,
)
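
"""
To make the `min_delta` and `patience` comments above concrete, here is a simplified
plain-Python sketch of the stopping rule. It is an approximation for illustration, not
the Keras implementation: options such as `baseline` and `restore_best_weights` are
ignored, and `first_stopping_epoch` is a hypothetical helper name.

```python
def first_stopping_epoch(val_losses, min_delta=1e-2, patience=2):
    # Return the 0-based epoch at which training would stop, or None if it never does.
    best = float("inf")
    wait = 0  # consecutive epochs without a sufficient improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # The loss dropped by more than `min_delta`: record it, reset the counter.
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# The loss stalls after epoch 1, so two non-improving epochs trigger the stop.
print(first_stopping_epoch([1.0, 0.8, 0.797, 0.795, 0.6]))  # 3
# The loss improves by more than `min_delta` every epoch, so training never stops early.
print(first_stopping_epoch([1.0, 0.9, 0.8, 0.7]))  # None
```

Note that in the first run the drop to 0.6 is never observed, since training has
already stopped by then. A larger `patience` makes this kind of miss less likely.
"""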

"""
### Many built-in callbacks are available

There are many built-in callbacks already available in Keras, such as:

- `ModelCheckpoint`: Periodically save the model.
- `EarlyStopping`: Stop training when training is no longer improving the validation
metrics.
- `TensorBoard`: Periodically write model logs that can be visualized in
[TensorBoard](https://www.tensorflow.org/tensorboard) (more details in the section
"Visualizing loss and metrics during training with TensorBoard").
- `CSVLogger`: Stream loss and metrics data to a CSV file.
- etc.

See the [callbacks documentation](/api/callbacks/) for the complete list.

### Writing your own callback

You can create a custom callback by extending the base class
`keras.callbacks.Callback`. A callback has access to its associated model through the
class property `self.model`.

Make sure to read the
[complete guide to writing custom callbacks](/guides/writing_your_own_callbacks/).

Here's a simple example saving a list of per-batch loss values during training:
"""


class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        # Start each training run with an empty history.
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs=None):
        # `logs` holds the metrics computed for the batch that just finished.
        self.per_batch_losses.append((logs or {}).get("loss"))


"""
## Checkpointing models

When you're training a model on relatively large datasets, it's crucial to save
checkpoints of your model at frequent intervals.

The easiest way to achieve this is with the `ModelCheckpoint` callback:
"""

model = get_compiled_model()

callbacks = [
    keras.callbacks.ModelCheckpoint(
        # Path where to save the model.
        # The saved model name will include the current epoch.
        filepath="mymodel_{epoch}.keras",
        # The two parameters below mean that a checkpoint is written
        # if and only if the `val_loss` score has improved.
        save_best_only=True,
        monitor="val_loss",
        verbose=1,
    )
]
model.fit(
    x_train,
    y_train,
    epochs=2,
    batch_size=64,
    callbacks=callbacks,
    validation_split=0.2,
)

"""
The `ModelCheckpoint` callback can be used to implement fault-tolerance:
the ability to restart training from the last saved state of the model in case training
gets unexpectedly interrupted. Here's a basic example:
"""

# Prepare a directory to store all the checkpoints.
checkpoint_dir = "./ckpt"
os.makedirs(checkpoint_dir, exist_ok=True)


def make_or_restore_model():
    # Either restore the latest model, or create a fresh one
    # if there is no checkpoint available.
    checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)]
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=os.path.getctime)
        print("Restoring from", latest_checkpoint)
        return keras.models.load_model(latest_checkpoint)
    print("Creating a new model")
    return get_compiled_model()


model = make_or_restore_model()
callbacks = [
    # This callback saves the model every 100 batches.
    # We include the training loss in the saved model name.
    keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_dir + "/model-loss={loss:.2f}.keras", save_freq=100
    )
]
model.fit(x_train, y_train, epochs=1, callbacks=callbacks)
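
"""
The `filepath` templates used in this section contain named formatting fields that are
filled in from the epoch number and the current metric values. The plain-Python sketch
below shows how the two templates expand, assuming standard `str.format` semantics and
a 1-based epoch number in the file name. The `logs` values are made up for illustration.

```python
# Hypothetical metric values, standing in for the logs available at save time.
logs = {"loss": 0.3149, "val_loss": 0.2871}
epoch_number = 1  # assumed to be 1-based in the file name

per_epoch_name = "mymodel_{epoch}.keras".format(epoch=epoch_number, **logs)
per_batch_name = "./ckpt/model-loss={loss:.2f}.keras".format(
    epoch=epoch_number, **logs
)

print(per_epoch_name)  # mymodel_1.keras
print(per_batch_name)  # ./ckpt/model-loss=0.31.keras
```

Because the loss is rounded to two decimals, two saves with the same rounded loss share
a file name, so the later save replaces the earlier one.
"""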

"""
You can also write your own callback for saving and restoring models.

For a complete guide on serialization and saving, see the
[guide to saving and serializing Models](/guides/serialization_and_saving/).
"""

"""
## Using learning rate schedules

A common pattern when training deep learning models is to gradually reduce the learning
rate as training progresses. This is generally known as "learning rate decay".

The learning rate decay schedule could be static (fixed in advance, as a function of
the current epoch or the current batch index), or dynamic (responding to the current
behavior of the model, in particular the validation loss).

### Passing a schedule to an optimizer

You can easily use a static learning rate decay schedule by passing a schedule object
as the `learning_rate` argument in your optimizer:
"""

initial_learning_rate = 0.1
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)

optimizer = keras.optimizers.RMSprop(learning_rate=lr_schedule)
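
"""
For intuition, the value produced by an exponential decay schedule can be computed by
hand as `initial_learning_rate * decay_rate ** (step / decay_steps)`, where
`staircase=True` rounds the exponent down to a whole number. The function below is a
plain-Python sketch of that formula with the same parameter values as above;
`exponential_decay` is a hypothetical helper, not part of the Keras API.

```python
import math


def exponential_decay(
    step,
    initial_learning_rate=0.1,
    decay_steps=100000,
    decay_rate=0.96,
    staircase=True,
):
    exponent = step / decay_steps
    if staircase:
        # Decay in discrete jumps, once every `decay_steps` optimizer steps.
        exponent = math.floor(exponent)
    return initial_learning_rate * decay_rate**exponent


for step in (0, 99999, 100000, 250000):
    print(step, exponential_decay(step))
```

With `staircase=True` the learning rate stays at 0.1 for the first 100,000 steps and
only then drops by a factor of 0.96; with `staircase=False` it shrinks a little on
every step.
"""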

"""
Several built-in schedules are available: `ExponentialDecay`, `PiecewiseConstantDecay`,
`PolynomialDecay`, and `InverseTimeDecay`.

### Using callbacks to implement a dynamic learning rate schedule

A dynamic learning rate schedule (for instance, decreasing the learning rate when the
validation loss is no longer improving) cannot be achieved with these schedule objects,
since the optimizer does not have access to validation metrics.

However, callbacks do have access to all metrics, including validation metrics! You can
thus achieve this pattern by using a callback that modifies the current learning rate
on the optimizer. In fact, this is even built-in as the `ReduceLROnPlateau` callback.
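
As a rough illustration of the rule such a callback applies, here is a plain-Python
simulation. The `factor`, `patience`, and `min_lr` parameters mirror arguments of
`ReduceLROnPlateau`, but the logic is a simplified sketch rather than the Keras
implementation (details such as `min_delta` and `cooldown` are left out), and
`simulate_plateau_schedule` is a hypothetical helper name.

```python
def simulate_plateau_schedule(val_losses, lr=0.1, factor=0.5, patience=2, min_lr=1e-3):
    # Return the learning rate in effect after each epoch's validation loss is seen.
    best = float("inf")
    wait = 0  # consecutive epochs without improvement
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                # Plateau detected: shrink the learning rate, but not below `min_lr`.
                lr = max(lr * factor, min_lr)
                wait = 0
        history.append(lr)
    return history


print(simulate_plateau_schedule([1.0, 0.9, 0.9, 0.9, 0.8, 0.8, 0.8]))
# [0.1, 0.1, 0.1, 0.05, 0.05, 0.05, 0.025]
```

Each time the validation loss fails to improve for `patience` epochs in a row, the
learning rate is multiplied by `factor`, which is exactly the kind of decision a static
schedule object cannot make.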
"""

"""
## Visualizing loss and metrics during training with TensorBoard

The best way to keep an eye on your model during training is to use
[TensorBoard](https://www.tensorflow.org/tensorboard) -- a browser-based application
that you can run locally that provides you with:

- Live plots of the loss and metrics for training and evaluation
- (optionally) Visualizations of the histograms of your layer activations
- (optionally) 3D visualizations of the embedding spaces learned by your `Embedding`
layers

If you have installed TensorFlow with pip, you should be able to launch TensorBoard
from the command line:

```
tensorboard --logdir=/full_path_to_your_logs
```
"""

"""
### Using the TensorBoard callback

The easiest way to use TensorBoard with a Keras model and the `fit()` method is the
`TensorBoard` callback.

In the simplest case, just specify where you want the callback to write logs, and
you're good to go:
"""

keras.callbacks.TensorBoard(
    log_dir="/full_path_to_your_logs",
    histogram_freq=0,  # How often to log histogram visualizations
    embeddings_freq=0,  # How often to log embedding visualizations
    update_freq="epoch",  # How often to write logs (default: once per epoch)
)

"""
For more information, see the
[documentation for the `TensorBoard` callback](https://keras.io/api/callbacks/tensorboard/).
"""