Path: blob/master/guides/training_with_built_in_methods.py
"""
Title: Training & evaluation with the built-in methods
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/03/01
Last modified: 2023/06/25
Description: Complete guide to training & evaluation with `fit()` and `evaluate()`.
Accelerator: GPU
"""

"""
## Setup
"""

# We import torch & TF so as to use torch DataLoaders & tf.data.Datasets.
import torch
import tensorflow as tf

import os
import numpy as np
import keras
from keras import layers
from keras import ops

"""
## Introduction

This guide covers training, evaluation, and prediction (inference) of models
when using built-in APIs for training & validation (such as `Model.fit()`,
`Model.evaluate()` and `Model.predict()`).

If you are interested in leveraging `fit()` while specifying your
own training step function, see the guides on customizing what happens in `fit()`:

- [Writing a custom train step with TensorFlow](/guides/custom_train_step_in_tensorflow/)
- [Writing a custom train step with JAX](/guides/custom_train_step_in_jax/)
- [Writing a custom train step with PyTorch](/guides/custom_train_step_in_torch/)

If you are interested in writing your own training & evaluation loops from
scratch, see the guides on writing training loops:

- [Writing a training loop with TensorFlow](/guides/writing_a_custom_training_loop_in_tensorflow/)
- [Writing a training loop with JAX](/guides/writing_a_custom_training_loop_in_jax/)
- [Writing a training loop with PyTorch](/guides/writing_a_custom_training_loop_in_torch/)

In general, whether you are using built-in loops or writing your own, model training &
evaluation works strictly in the same way across every kind of Keras model --
Sequential models, models built with the Functional API, and models written from
scratch via model subclassing.
"""

"""
## API overview: a first end-to-end example

When passing data to the built-in training loops of a model, you should use one of:

- NumPy arrays (if your data is
small and fits in memory)
- Subclasses of `keras.utils.PyDataset`
- `tf.data.Dataset` objects
- PyTorch `DataLoader` instances

In the next few paragraphs, we'll use the MNIST dataset as NumPy arrays, in
order to demonstrate how to use optimizers, losses, and metrics. Afterwards, we'll
take a close look at each of the other options.

Let's consider the following model (here, we build it with the Functional API, but it
could be a Sequential model or a subclassed model as well):
"""

inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, activation="softmax", name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

"""
Here's what the typical end-to-end workflow looks like, consisting of:

- Training
- Validation on a holdout set generated from the original training data
- Evaluation on the test data

We'll use MNIST data for this example.
"""

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Preprocess the data (these are NumPy arrays)
x_train = x_train.reshape(60000, 784).astype("float32") / 255
x_test = x_test.reshape(10000, 784).astype("float32") / 255

y_train = y_train.astype("float32")
y_test = y_test.astype("float32")

# Reserve 10,000 samples for validation
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

"""
We specify the training configuration (optimizer, loss, metrics):
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(),  # Optimizer
    # Loss function to minimize
    loss=keras.losses.SparseCategoricalCrossentropy(),
    # List of metrics to monitor
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

"""
We call `fit()`, which will train the model by slicing the data into "batches" of size
`batch_size`, and repeatedly iterating over the
entire dataset for a given number of
`epochs`.
"""

print("Fit model on training data")
history = model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=2,
    # We pass some validation data for
    # monitoring validation loss and metrics
    # at the end of each epoch
    validation_data=(x_val, y_val),
)

"""
The returned `history` object holds a record of the loss values and metric values
during training:
"""

print(history.history)

"""
We evaluate the model on the test data via `evaluate()`:
"""

# Evaluate the model on the test data using `evaluate`
print("Evaluate on test data")
results = model.evaluate(x_test, y_test, batch_size=128)
print("test loss, test acc:", results)

# Generate predictions (probabilities -- the output of the last layer)
# on new data using `predict`
print("Generate predictions for 3 samples")
predictions = model.predict(x_test[:3])
print("predictions shape:", predictions.shape)

"""
Now, let's review each piece of this workflow in detail.
"""

"""
## The `compile()` method: specifying a loss, metrics, and an optimizer

To train a model with `fit()`, you need to specify a loss function, an optimizer, and
optionally, some metrics to monitor.

You pass these to the model as arguments to the `compile()` method:
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

"""
The `metrics` argument should be a list -- your model can have any number of metrics.

If your model has multiple outputs, you can specify different losses and metrics for
each output, and you can modulate the contribution of each output to the total loss of
the model.
You will find more details about this in the **Passing data to multi-input,
multi-output models** section.

Note that if you're satisfied with the default settings, in many cases the optimizer,
loss, and metrics can be specified via string identifiers as a shortcut:
"""

model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)

"""
For later reuse, let's put our model definition and compile step in functions; we will
call them several times across different examples in this guide.
"""


def get_uncompiled_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
    x = layers.Dense(64, activation="relu", name="dense_2")(x)
    outputs = layers.Dense(10, activation="softmax", name="predictions")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


def get_compiled_model():
    model = get_uncompiled_model()
    model.compile(
        optimizer="rmsprop",
        loss="sparse_categorical_crossentropy",
        metrics=["sparse_categorical_accuracy"],
    )
    return model


"""
### Many built-in optimizers, losses, and metrics are available

In general, you won't have to create your own losses, metrics, or optimizers
from scratch, because what you need is likely to be already part of the Keras API:

Optimizers:

- `SGD()` (with or without momentum)
- `RMSprop()`
- `Adam()`
- etc.

Losses:

- `MeanSquaredError()`
- `KLDivergence()`
- `CosineSimilarity()`
- etc.

Metrics:

- `AUC()`
- `Precision()`
- `Recall()`
- etc.
"""

"""
### Custom losses

If you need to create a custom loss, Keras provides three ways to do so.

The first method involves creating a function that accepts inputs `y_true` and
`y_pred`.
The following example shows a loss function that computes the mean squared
error between the real data and the predictions:
"""


def custom_mean_squared_error(y_true, y_pred):
    return ops.mean(ops.square(y_true - y_pred), axis=-1)


model = get_uncompiled_model()
model.compile(optimizer=keras.optimizers.Adam(), loss=custom_mean_squared_error)

# We need to one-hot encode the labels to use MSE
y_train_one_hot = ops.one_hot(y_train, num_classes=10)
model.fit(x_train, y_train_one_hot, batch_size=64, epochs=1)

"""
If you need a loss function that takes in parameters besides `y_true` and `y_pred`, you
can subclass the `keras.losses.Loss` class and implement the following two methods:

- `__init__(self)`: accept parameters to pass during the call of your loss function
- `call(self, y_true, y_pred)`: use the targets (`y_true`) and the model predictions
(`y_pred`) to compute the model's loss

Let's say you want to use mean squared error, but with an added term that
will de-incentivize prediction values far from 0.5 (we assume that the categorical
targets are one-hot encoded and take values between 0 and 1).
This creates an incentive for the model not to be too confident, which may help
reduce overfitting (we won't know if it works until we try!).

Here's how you would do it:
"""


class CustomMSE(keras.losses.Loss):
    def __init__(self, regularization_factor=0.1, name="custom_mse"):
        super().__init__(name=name)
        self.regularization_factor = regularization_factor

    def call(self, y_true, y_pred):
        mse = ops.mean(ops.square(y_true - y_pred), axis=-1)
        reg = ops.mean(ops.square(0.5 - y_pred), axis=-1)
        return mse + reg * self.regularization_factor


model = get_uncompiled_model()
model.compile(optimizer=keras.optimizers.Adam(), loss=CustomMSE())

y_train_one_hot = ops.one_hot(y_train, num_classes=10)
model.fit(x_train, y_train_one_hot, batch_size=64, epochs=1)


"""
### Custom metrics

If you need a metric that isn't part of the API, you can easily create custom metrics
by subclassing the `keras.metrics.Metric` class. You will need to implement 4
methods:

- `__init__(self)`, in which you will create state variables for your metric.
- `update_state(self, y_true, y_pred, sample_weight=None)`, which uses the targets
y_true and the model predictions y_pred to update the state variables.
- `result(self)`, which uses the state variables to compute the final results.
- `reset_state(self)`, which reinitializes the state of the metric.

State update and results computation are kept separate (in `update_state()` and
`result()`, respectively) because in some cases, the results computation might be very
expensive and would only be done periodically.

Here's a simple example showing how to implement a `CategoricalTruePositives` metric
that counts how many samples were correctly classified as belonging to a given class:
"""


class CategoricalTruePositives(keras.metrics.Metric):
    def __init__(self, name="categorical_true_positives", **kwargs):
        super().__init__(name=name, **kwargs)
        self.true_positives = self.add_variable(
            shape=(), name="ctp", initializer="zeros"
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = ops.reshape(ops.argmax(y_pred, axis=1), (-1, 1))
        values = ops.cast(y_true, "int32") == ops.cast(y_pred, "int32")
        values = ops.cast(values, "float32")
        if sample_weight is not None:
            sample_weight = ops.cast(sample_weight, "float32")
            values = ops.multiply(values, sample_weight)
        self.true_positives.assign_add(ops.sum(values))

    def result(self):
        return self.true_positives.value

    def reset_state(self):
        # The state of the metric will be reset at the start of each epoch.
        self.true_positives.assign(0.0)


model = get_uncompiled_model()
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[CategoricalTruePositives()],
)
model.fit(x_train, y_train, batch_size=64, epochs=3)

"""
### Handling losses and metrics that don't fit the standard signature

The overwhelming majority of losses and metrics can be computed from `y_true` and
`y_pred`, where `y_pred` is an output of your model -- but not all of them. For
instance, a regularization loss may only require the activation of a layer (there are
no targets in this case), and this activation may not be a model output.

In such cases, you can call `self.add_loss(loss_value)` from inside the call method of
a custom layer. Losses added in this way get added to the "main" loss during training
(the one passed to `compile()`).
Here's a simple example that adds activity
regularization (note that activity regularization is built into all Keras layers --
this layer is just for the sake of providing a concrete example):
"""


class ActivityRegularizationLayer(layers.Layer):
    def call(self, inputs):
        self.add_loss(ops.sum(inputs) * 0.1)
        return inputs  # Pass-through layer.


inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)

# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)

x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# The displayed loss will be much higher than before
# due to the regularization component.
model.fit(x_train, y_train, batch_size=64, epochs=1)

"""
Note that when you pass losses via `add_loss()`, it becomes possible to call
`compile()` without a loss function, since the model already has a loss to minimize.

Consider the following `LogisticEndpoint` layer: it takes as inputs
targets & logits, and it tracks a crossentropy loss via `add_loss()`.
"""


class LogisticEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)

    def call(self, targets, logits, sample_weights=None):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        loss = self.loss_fn(targets, logits, sample_weights)
        self.add_loss(loss)

        # Return the inference-time prediction tensor (for `.predict()`).
        return ops.softmax(logits)


"""
You can use it in a model with two inputs (input data & targets), compiled without a
`loss`
argument, like this:
"""

inputs = keras.Input(shape=(3,), name="inputs")
targets = keras.Input(shape=(10,), name="targets")
logits = keras.layers.Dense(10)(inputs)
predictions = LogisticEndpoint(name="predictions")(targets, logits)

model = keras.Model(inputs=[inputs, targets], outputs=predictions)
model.compile(optimizer="adam")  # No loss argument!

data = {
    "inputs": np.random.random((3, 3)),
    "targets": np.random.random((3, 10)),
}
model.fit(data)

"""
For more information about training multi-input models, see the section **Passing data
to multi-input, multi-output models**.
"""

"""
### Automatically setting apart a validation holdout set

In the first end-to-end example you saw, we used the `validation_data` argument to pass
a tuple of NumPy arrays `(x_val, y_val)` to the model for evaluating a validation loss
and validation metrics at the end of each epoch.

Here's another option: the argument `validation_split` allows you to automatically
reserve part of your training data for validation. The argument value represents the
fraction of the data to be reserved for validation, so it should be set to a number
higher than 0 and lower than 1.
For instance, `validation_split=0.2` means "use 20% of
the data for validation", and `validation_split=0.6` means "use 60% of the data for
validation".

The way the validation is computed is by taking the last x% samples of the arrays
received by the `fit()` call, before any shuffling.

Note that you can only use `validation_split` when training with NumPy data.
"""

model = get_compiled_model()
model.fit(x_train, y_train, batch_size=64, validation_split=0.2, epochs=1)

"""
## Training & evaluation using `tf.data` Datasets

In the past few paragraphs, you've seen how to handle losses, metrics, and optimizers,
and you've seen how to use the `validation_data` and `validation_split` arguments in
`fit()`, when your data is passed as NumPy arrays.

Another option is to use an iterator-like object, such as a `tf.data.Dataset`, a
PyTorch `DataLoader`, or a Keras `PyDataset`. Let's take a look at the first of these.

The `tf.data` API is a set of utilities in TensorFlow 2.0 for loading and preprocessing
data in a way that's fast and scalable.
For a complete guide about creating `Datasets`,
see the [tf.data documentation](https://www.tensorflow.org/guide/data).

**You can use `tf.data` to train your Keras
models regardless of the backend you're using --
whether it's JAX, PyTorch, or TensorFlow.**
You can pass a `Dataset` instance directly to the methods `fit()`, `evaluate()`, and
`predict()`:
"""

model = get_compiled_model()

# First, let's create a training Dataset instance.
# For the sake of our example, we'll use the same MNIST data as before.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# Now we get a test dataset.
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(64)

# Since the dataset already takes care of batching,
# we don't pass a `batch_size` argument.
model.fit(train_dataset, epochs=3)

# You can also evaluate or predict on a dataset.
print("Evaluate")
result = model.evaluate(test_dataset)
dict(zip(model.metrics_names, result))

"""
Note that the Dataset is reset at the end of each epoch, so it can be reused for the
next epoch.

If you want to run training only on a specific number of batches from this Dataset, you
can pass the `steps_per_epoch` argument, which specifies how many training steps the
model should run using this Dataset before moving on to the next epoch.
"""

model = get_compiled_model()

# Prepare the training dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# Only use 100 batches per epoch (that's 64 * 100 samples)
model.fit(train_dataset, epochs=3, steps_per_epoch=100)

"""
You can also pass a `Dataset` instance as the `validation_data` argument in `fit()`:
"""

model = get_compiled_model()

# Prepare the training dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# Prepare the validation dataset
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(64)

model.fit(train_dataset, epochs=1, validation_data=val_dataset)

"""
At the end of each epoch, the model will iterate over the validation dataset and
compute the validation loss and validation metrics.

If you want to run validation only on a specific number of batches from this dataset,
you can pass the `validation_steps` argument, which specifies how many validation
steps the model should run with the validation dataset before interrupting validation
and moving on to the next epoch:
"""

model = get_compiled_model()

# Prepare the training dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# Prepare the validation dataset
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(64)

model.fit(
    train_dataset,
    epochs=1,
    # Only run validation using the first 10 batches of the dataset
    # using the `validation_steps` argument
    validation_data=val_dataset,
    validation_steps=10,
)

"""
Note that the validation dataset will be reset after each use (so that you will always
be evaluating on the same samples from epoch to epoch).

The argument `validation_split` (generating a holdout set from the training data) is
not supported when training from `Dataset` objects, since this feature requires the
ability to index the samples of the datasets, which is not possible in general with
the `Dataset` API.
"""

"""
## Training & evaluation using `PyDataset` instances

`keras.utils.PyDataset` is a utility that you can subclass
to obtain
a Python generator with two important properties:

- It works well with multiprocessing.
- It can be shuffled (e.g. when passing `shuffle=True` in `fit()`).

A `PyDataset` must implement two methods:

- `__getitem__`
- `__len__`

The method `__getitem__` should return a complete batch.
If you want to modify your dataset between epochs, you may implement `on_epoch_end`.

Here's a quick example:
"""


class ExamplePyDataset(keras.utils.PyDataset):
    def __init__(self, x, y, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.x = x
        self.y = y
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size : (idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size : (idx + 1) * self.batch_size]
        return batch_x, batch_y


train_py_dataset = ExamplePyDataset(x_train, y_train, batch_size=32)
val_py_dataset = ExamplePyDataset(x_val, y_val, batch_size=32)

"""
To fit the model, pass the dataset as the `x` argument (no need for a `y`
argument since the dataset includes the targets), and pass the validation dataset
as the `validation_data` argument. There is no need for the `batch_size` argument,
since the dataset is already batched!
"""

model = get_compiled_model()
model.fit(train_py_dataset, validation_data=val_py_dataset, epochs=1)

"""
Evaluating the model is just as easy:
"""

model.evaluate(val_py_dataset)

"""
Importantly, `PyDataset` objects support three common constructor arguments
that handle the parallel processing configuration:

- `workers`: Number of workers to use in multithreading or
multiprocessing. Typically, you'd set it to the number of
cores on your CPU.
- `use_multiprocessing`: Whether to use Python multiprocessing for
parallelism.
Setting this to `True` means that your
dataset will be replicated in multiple forked processes.
This is necessary to gain compute-level (rather than I/O level)
benefits from parallelism. However, it can only be set to
`True` if your dataset can be safely pickled.
- `max_queue_size`: Maximum number of batches to keep in the queue
when iterating over the dataset in a multithreaded or
multiprocessed setting.
You can reduce this value to reduce the CPU memory consumption of
your dataset. It defaults to 10.

By default, multiprocessing is disabled (`use_multiprocessing=False`) and only
one thread is used. You should make sure to only turn on `use_multiprocessing` if
your code is running inside a Python `if __name__ == "__main__":` block in order
to avoid issues.

Here's a 4-thread, non-multiprocessed example:
"""

train_py_dataset = ExamplePyDataset(x_train, y_train, batch_size=32, workers=4)
val_py_dataset = ExamplePyDataset(x_val, y_val, batch_size=32, workers=4)

model = get_compiled_model()
model.fit(train_py_dataset, validation_data=val_py_dataset, epochs=1)

"""
## Training & evaluation using PyTorch `DataLoader` objects

All built-in training and evaluation APIs are also compatible with `torch.utils.data.Dataset` and
`torch.utils.data.DataLoader` objects -- regardless of whether you're using the PyTorch backend,
or the JAX or TensorFlow backends.
Let's take a look at a simple example.

Unlike `PyDataset` objects, which are batch-centric, PyTorch `Dataset` objects are sample-centric:
the `__len__` method returns the number of samples,
and the `__getitem__` method returns a specific sample.
"""


class ExampleTorchDataset(torch.utils.data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


train_torch_dataset = ExampleTorchDataset(x_train, y_train)
val_torch_dataset = ExampleTorchDataset(x_val, y_val)

"""
To use a PyTorch Dataset, you need to wrap it into a `DataLoader`, which takes care
of batching and shuffling:
"""

train_dataloader = torch.utils.data.DataLoader(
    train_torch_dataset, batch_size=32, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_torch_dataset, batch_size=32, shuffle=True
)

"""
Now you can use them in the Keras API just like any other iterator. The `DataLoader`
already batches the data, so no `batch_size` argument is passed to `fit()`:
"""

model = get_compiled_model()
model.fit(train_dataloader, validation_data=val_dataloader, epochs=1)
model.evaluate(val_dataloader)

"""
## Using sample weighting and class weighting

With the default settings the weight of a sample is decided by its frequency
in the dataset. There are two methods to weight the data, independent of
sample frequency:

* Class weights
* Sample weights
"""

"""
### Class weights

This is set by passing a dictionary to the `class_weight` argument to
`Model.fit()`.
This dictionary maps class indices to the weight that should
be used for samples belonging to this class.

This can be used to balance classes without resampling, or to train a
model that gives more importance to a particular class.

For instance, if class "0" is half as represented as class "1" in your data,
you could use `Model.fit(..., class_weight={0: 1., 1: 0.5})`.
"""

"""
Here's a NumPy example where we use class weights or sample weights to
give more importance to the correct classification of class #5 (which
is the digit "5" in the MNIST dataset).
"""

class_weight = {
    0: 1.0,
    1: 1.0,
    2: 1.0,
    3: 1.0,
    4: 1.0,
    # Set weight "2" for class "5",
    # making this class 2x more important
    5: 2.0,
    6: 1.0,
    7: 1.0,
    8: 1.0,
    9: 1.0,
}

print("Fit with class weight")
model = get_compiled_model()
model.fit(x_train, y_train, class_weight=class_weight, batch_size=64, epochs=1)

"""
### Sample weights

For fine-grained control, or if you are not building a classifier,
you can use "sample weights".

- When training from NumPy data: Pass the `sample_weight`
argument to `Model.fit()`.
- When training from `tf.data` or any other sort of iterator:
Yield `(input_batch, label_batch, sample_weight_batch)` tuples.

A "sample weights" array is an array of numbers that specify how much weight
each sample in a batch should have in computing the total loss.
It is commonly
used in imbalanced classification problems (the idea being to give more weight
to rarely-seen classes).

When the weights used are ones and zeros, the array can be used as a *mask* for
the loss function (entirely discarding the contribution of certain samples to
the total loss).
"""

sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train == 5] = 2.0

print("Fit with sample weight")
model = get_compiled_model()
model.fit(x_train, y_train, sample_weight=sample_weight, batch_size=64, epochs=1)

"""
Here's a matching `Dataset` example:
"""

sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train == 5] = 2.0

# Create a Dataset that includes sample weights
# (3rd element in the return tuple).
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train, sample_weight))

# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

model = get_compiled_model()
model.fit(train_dataset, epochs=1)

"""
## Passing data to multi-input, multi-output models

In the previous examples, we were considering a model with a single input (a tensor of
shape `(784,)`) and a single output (a prediction tensor of shape `(10,)`). But what
about models that have multiple inputs or outputs?

Consider the following model, which has an image input of shape `(32, 32, 3)` (that's
`(height, width, channels)`) and a time series input of shape `(None, 10)` (that's
`(timesteps, features)`).
Our model will have two outputs computed from the
combination of these inputs: a "score" (of shape `(1,)`) and a probability
distribution over five classes (of shape `(5,)`).
"""

image_input = keras.Input(shape=(32, 32, 3), name="img_input")
timeseries_input = keras.Input(shape=(None, 10), name="ts_input")

x1 = layers.Conv2D(3, 3)(image_input)
x1 = layers.GlobalMaxPooling2D()(x1)

x2 = layers.Conv1D(3, 3)(timeseries_input)
x2 = layers.GlobalMaxPooling1D()(x2)

x = layers.concatenate([x1, x2])

score_output = layers.Dense(1, name="score_output")(x)
class_output = layers.Dense(5, name="class_output")(x)

model = keras.Model(
    inputs=[image_input, timeseries_input], outputs=[score_output, class_output]
)

"""
Let's plot this model, so you can clearly see what we're doing here (note that the
shapes shown in the plot are batch shapes, rather than per-sample shapes).
"""

keras.utils.plot_model(model, "multi_input_and_output_model.png", show_shapes=True)

"""
At compilation time, we can specify different losses for different outputs, by passing
the loss functions as a list:
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss=[
        keras.losses.MeanSquaredError(),
        keras.losses.CategoricalCrossentropy(),
    ],
)

"""
If we only passed a single loss function to the model, the same loss function would be
applied to every output (which is not appropriate here).

Likewise for metrics:
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss=[
        keras.losses.MeanSquaredError(),
        keras.losses.CategoricalCrossentropy(),
    ],
    metrics=[
        [
            keras.metrics.MeanAbsolutePercentageError(),
            keras.metrics.MeanAbsoluteError(),
        ],
        [keras.metrics.CategoricalAccuracy()],
    ],
)

"""
Since we gave names to our output layers, we could also specify per-output losses and
metrics via a
dict:
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss={
        "score_output": keras.losses.MeanSquaredError(),
        "class_output": keras.losses.CategoricalCrossentropy(),
    },
    metrics={
        "score_output": [
            keras.metrics.MeanAbsolutePercentageError(),
            keras.metrics.MeanAbsoluteError(),
        ],
        "class_output": [keras.metrics.CategoricalAccuracy()],
    },
)

"""
We recommend the use of explicit names and dicts if you have more than 2 outputs.

It's possible to give different weights to different output-specific losses (for
instance, one might wish to privilege the "score" loss in our example, by giving it 2x
the importance of the class loss), using the `loss_weights` argument:
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss={
        "score_output": keras.losses.MeanSquaredError(),
        "class_output": keras.losses.CategoricalCrossentropy(),
    },
    metrics={
        "score_output": [
            keras.metrics.MeanAbsolutePercentageError(),
            keras.metrics.MeanAbsoluteError(),
        ],
        "class_output": [keras.metrics.CategoricalAccuracy()],
    },
    loss_weights={"score_output": 2.0, "class_output": 1.0},
)

"""
You could also choose not to compute a loss for certain outputs, if these outputs are
meant for prediction but not for training:
"""

# List loss version
model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss=[None, keras.losses.CategoricalCrossentropy()],
)

# Or dict loss version
model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss={"class_output": keras.losses.CategoricalCrossentropy()},
)

"""
Passing data to a multi-input or multi-output model in `fit()` works in a similar way as
specifying a loss function in compile: you can pass **lists of NumPy arrays** (with
1:1 mapping to the outputs that received a loss function) or **dicts mapping output
names to NumPy
arrays**.
"""

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss=[
        keras.losses.MeanSquaredError(),
        keras.losses.CategoricalCrossentropy(),
    ],
)

# Generate dummy NumPy data
img_data = np.random.random_sample(size=(100, 32, 32, 3))
ts_data = np.random.random_sample(size=(100, 20, 10))
score_targets = np.random.random_sample(size=(100, 1))
class_targets = np.random.random_sample(size=(100, 5))

# Fit on lists
model.fit([img_data, ts_data], [score_targets, class_targets], batch_size=32, epochs=1)

# Alternatively, fit on dicts
model.fit(
    {"img_input": img_data, "ts_input": ts_data},
    {"score_output": score_targets, "class_output": class_targets},
    batch_size=32,
    epochs=1,
)

"""
Here's the `Dataset` use case: similarly to what we did for NumPy arrays, the `Dataset`
should return a tuple of dicts.
"""

train_dataset = tf.data.Dataset.from_tensor_slices(
    (
        {"img_input": img_data, "ts_input": ts_data},
        {"score_output": score_targets, "class_output": class_targets},
    )
)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

model.fit(train_dataset, epochs=1)

"""
## Using callbacks

Callbacks in Keras are objects that are called at different points during training (at
the start of an epoch, at the end of a batch, at the end of an epoch, etc.).
They can be used to implement certain behaviors, such as:

- Doing validation at different points during training (beyond the built-in per-epoch
validation)
- Checkpointing the model at regular intervals or when it exceeds a certain accuracy
threshold
- Changing the learning rate of the model when training seems to be plateauing
- Doing fine-tuning of the top layers when training seems to be plateauing
- Sending email or instant message notifications when training ends or when a certain
performance threshold is exceeded
- Etc.

Callbacks can be passed as a list to your call to `fit()`:
"""

model = get_compiled_model()

callbacks = [
    keras.callbacks.EarlyStopping(
        # Stop training when `val_loss` is no longer improving
        monitor="val_loss",
        # "no longer improving" being defined as "no better than 1e-2 less"
        min_delta=1e-2,
        # "no longer improving" being further defined as "for at least 2 epochs"
        patience=2,
        verbose=1,
    )
]
model.fit(
    x_train,
    y_train,
    epochs=20,
    batch_size=64,
    callbacks=callbacks,
    validation_split=0.2,
)

"""
### Many built-in callbacks are available

There are many built-in callbacks already available in Keras, such as:

- `ModelCheckpoint`: Periodically save the model.
- `EarlyStopping`: Stop training when training is no longer improving the validation
metrics.
- `TensorBoard`: Periodically write model logs that can be visualized in
[TensorBoard](https://www.tensorflow.org/tensorboard) (more details in the section
"Visualization").
- `CSVLogger`: Stream loss and metrics data to a CSV file.
- etc.

See the [callbacks documentation](/api/callbacks/) for the complete list.

### Writing your own callback

You can create a custom callback by extending the base class
`keras.callbacks.Callback`.
A callback has access to its associated model through the
class property `self.model`.

Make sure to read the
[complete guide to writing custom callbacks](/guides/writing_your_own_callbacks/).

Here's a simple example saving a list of per-batch loss values during training:
"""


class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs):
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs):
        self.per_batch_losses.append(logs.get("loss"))


"""
## Checkpointing models

When you're training a model on relatively large datasets, it's crucial to save
checkpoints of your model at frequent intervals.

The easiest way to achieve this is with the `ModelCheckpoint` callback:
"""

model = get_compiled_model()

callbacks = [
    keras.callbacks.ModelCheckpoint(
        # Path where to save the model
        # The two parameters below mean that we will overwrite
        # the current checkpoint if and only if
        # the `val_loss` score has improved.
        # The saved model name will include the current epoch.
        filepath="mymodel_{epoch}.keras",
        save_best_only=True,  # Only save a model if `val_loss` has improved.
        monitor="val_loss",
        verbose=1,
    )
]
model.fit(
    x_train,
    y_train,
    epochs=2,
    batch_size=64,
    callbacks=callbacks,
    validation_split=0.2,
)

"""
The `ModelCheckpoint` callback can be used to implement fault-tolerance:
the ability to restart training from the last saved state of the model in case training
gets randomly interrupted.
Here's a basic example:
"""

# Prepare a directory to store all the checkpoints.
checkpoint_dir = "./ckpt"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)


def make_or_restore_model():
    # Either restore the latest model, or create a fresh one
    # if there is no checkpoint available.
    checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)]
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=os.path.getctime)
        print("Restoring from", latest_checkpoint)
        return keras.models.load_model(latest_checkpoint)
    print("Creating a new model")
    return get_compiled_model()


model = make_or_restore_model()
callbacks = [
    # This callback saves the model every 100 batches.
    # We include the training loss in the saved model name.
    keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_dir + "/model-loss={loss:.2f}.keras", save_freq=100
    )
]
model.fit(x_train, y_train, epochs=1, callbacks=callbacks)

"""
You can also write your own callback for saving and restoring models.

For a complete guide on serialization and saving, see the
[guide to saving and serializing Models](/guides/serialization_and_saving/).
"""

"""
## Using learning rate schedules

A common pattern when training deep learning models is to gradually reduce the learning
rate as training progresses.
This is generally known as "learning rate decay".

The learning rate decay schedule could be static (fixed in advance, as a function of the
current epoch or the current batch index), or dynamic (responding to the current
behavior of the model, in particular the validation loss).

### Passing a schedule to an optimizer

You can easily use a static learning rate decay schedule by passing a schedule object
as the `learning_rate` argument in your optimizer:
"""

initial_learning_rate = 0.1
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)

optimizer = keras.optimizers.RMSprop(learning_rate=lr_schedule)

"""
Several built-in schedules are available: `ExponentialDecay`, `PiecewiseConstantDecay`,
`PolynomialDecay`, and `InverseTimeDecay`.

### Using callbacks to implement a dynamic learning rate schedule

A dynamic learning rate schedule (for instance, decreasing the learning rate when the
validation loss is no longer improving) cannot be achieved with these schedule objects,
since the optimizer does not have access to validation metrics.

However, callbacks do have access to all metrics, including validation metrics! You can
thus achieve this pattern by using a callback that modifies the current learning rate
on the optimizer.
In fact, this is even built-in as the `ReduceLROnPlateau` callback.
"""

"""
## Visualizing loss and metrics during training with TensorBoard

The best way to keep an eye on your model during training is to use
[TensorBoard](https://www.tensorflow.org/tensorboard) -- a browser-based application
that you can run locally that provides you with:

- Live plots of the loss and metrics for training and evaluation
- (optionally) Visualizations of the histograms of your layer activations
- (optionally) 3D visualizations of the embedding spaces learned by your `Embedding`
layers

If you have installed TensorFlow with pip, you should be able to launch TensorBoard
from the command line:

```
tensorboard --logdir=/full_path_to_your_logs
```
"""

"""
### Using the TensorBoard callback

The easiest way to use TensorBoard with a Keras model and the `fit()` method is the
`TensorBoard` callback.

In the simplest case, just specify where you want the callback to write logs, and
you're good to go:
"""

keras.callbacks.TensorBoard(
    log_dir="/full_path_to_your_logs",
    histogram_freq=0,  # How often to log histogram visualizations
    embeddings_freq=0,  # How often to log embedding visualizations
    update_freq="epoch",  # How often to write logs (default: once per epoch)
)

"""
For more information, see the
[documentation for the `TensorBoard` callback](https://keras.io/api/callbacks/tensorboard/).
"""
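
"""
As a rough illustration of how the last two sections fit together, the sketch below
passes both a `TensorBoard` callback and a `ReduceLROnPlateau` callback to a single
`fit()` call. The toy model, the random data, the temporary log directory, and the
`factor` and `patience` values are all assumptions made for this example, not
recommended settings. It also assumes TensorFlow is installed, since the `TensorBoard`
callback uses it to write its log files.

```python
import os
import tempfile

import numpy as np
import keras
from keras import layers

# Hypothetical toy data and model, used only to exercise the callbacks.
x = np.random.random_sample(size=(256, 16)).astype("float32")
y = np.random.random_sample(size=(256, 1)).astype("float32")

model = keras.Sequential(
    [
        keras.Input(shape=(16,)),
        layers.Dense(8, activation="relu"),
        layers.Dense(1),
    ]
)
model.compile(optimizer=keras.optimizers.RMSprop(learning_rate=1e-2), loss="mse")

# Temporary directory standing in for "/full_path_to_your_logs".
log_dir = tempfile.mkdtemp()

callbacks = [
    # Write loss and metric logs once per epoch.
    keras.callbacks.TensorBoard(log_dir=log_dir, update_freq="epoch"),
    # Halve the learning rate when `val_loss` has not improved for 1 epoch.
    keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=1),
]

history = model.fit(
    x,
    y,
    epochs=3,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks,
    verbose=0,
)

print("Log directory contents:", os.listdir(log_dir))
```

Pointing `tensorboard --logdir` at the printed directory should then show the training
and validation loss curves for this run.
"""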