Copyright 2021 The TensorFlow Authors.
On-Device Training with TensorFlow Lite
When deploying a TensorFlow Lite machine learning model to a device or mobile app, you may want to enable the model to be improved or personalized based on input from the device or end user. On-device training techniques allow you to update a model without data leaving your users' devices, which improves user privacy, and without requiring users to update the device software.
For example, you may have a model in your mobile app that recognizes fashion items, but you want users to get improved recognition performance over time based on their interests. With on-device training, the model can get better at recognizing a particular style or brand of shoe the more often a user who is interested in shoes uses your app.
This tutorial shows you how to construct a TensorFlow Lite model that can be incrementally trained and improved within an installed Android app.
Note: The on-device training technique can be added to existing TensorFlow Lite implementations, provided the devices you are targeting support local file storage.
Setup
This tutorial uses Python to train and convert a TensorFlow model before incorporating it into an Android app. Get started by installing and importing the required packages, including TensorFlow.
Note: The On-Device Training APIs are available in TensorFlow version 2.7 and higher.
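The following is a minimal setup sketch. It assumes TensorFlow and NumPy are already installed, for example with pip install tensorflow numpy. It imports both packages and checks the 2.7 minimum version from the note above.

```python
import numpy as np
import tensorflow as tf

# The on-device training APIs require TensorFlow 2.7 or higher.
major, minor = (int(part) for part in tf.__version__.split(".")[:2])
assert (major, minor) >= (2, 7), f"TensorFlow {tf.__version__} is too old"

print("TensorFlow version:", tf.__version__)
print("NumPy version:", np.__version__)
```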
Classify images of clothing
This example uses the Fashion MNIST dataset to train a neural network model for classifying images of clothing. The dataset contains 70,000 small (28 x 28 pixel) grayscale images, 60,000 for training and 10,000 for testing. The images fall into 10 categories of clothing and accessories, including dresses, shirts, and sandals.

You can explore this dataset in more depth in the Keras classification tutorial.
Build a model for on-device training
TensorFlow Lite models typically have only a single exposed function method (or signature) that allows you to call the model to run an inference. For a model to be trained and used on a device, you must be able to perform several separate operations, including train, infer, save, and restore functions for the model. You can enable this functionality by first extending your TensorFlow model to have multiple functions, and then exposing those functions as signatures when you convert your model to the TensorFlow Lite model format.
You can add the following functions to a TensorFlow model:
- train: trains the model with training data.
- infer: invokes the inference.
- save: saves the trainable weights into the file system.
- restore: loads the trainable weights from the file system.
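Below is a minimal sketch of such a module. It makes a few illustrative choices that are not taken from the original tutorial:

- It builds a small two-layer dense classifier from raw tf.Variable objects instead of a Keras model, so the checkpoint tensor names are explicit.
- It trains with plain stochastic gradient descent.
- The class name Model, the layer sizes, the learning rate, and the tensor names w1, b1, w2 and b2 are all hypothetical.

Each tf.function declares an input_signature so it can later be exported as a fixed TensorFlow Lite signature. The save and restore functions use tf.raw_ops.Save and tf.raw_ops.Restore, which write and read a TensorFlow 1 format checkpoint.

```python
import tensorflow as tf

IMG_SIZE = 28      # Fashion MNIST images are 28 x 28 pixels
NUM_CLASSES = 10   # ten clothing categories


class Model(tf.Module):
    """Hypothetical two-layer classifier exposing train/infer/save/restore."""

    def __init__(self, hidden_units=128, learning_rate=0.1):
        super().__init__()
        self.w1 = tf.Variable(
            tf.random.normal([IMG_SIZE * IMG_SIZE, hidden_units], stddev=0.05), name="w1")
        self.b1 = tf.Variable(tf.zeros([hidden_units]), name="b1")
        self.w2 = tf.Variable(
            tf.random.normal([hidden_units, NUM_CLASSES], stddev=0.05), name="w2")
        self.b2 = tf.Variable(tf.zeros([NUM_CLASSES]), name="b2")
        self.learning_rate = learning_rate

    def _weights(self):
        # Fixed names, used as the tensor names inside the checkpoint file.
        return {"w1": self.w1, "b1": self.b1, "w2": self.w2, "b2": self.b2}

    def _logits(self, x):
        x = tf.reshape(x, [-1, IMG_SIZE * IMG_SIZE])
        hidden = tf.nn.relu(tf.matmul(x, self.w1) + self.b1)
        return tf.matmul(hidden, self.w2) + self.b2

    @tf.function(input_signature=[
        tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
        tf.TensorSpec([None, NUM_CLASSES], tf.float32),
    ])
    def train(self, x, y):
        # GradientTape records the forward pass so gradients can be computed.
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=self._logits(x)))
        variables = list(self._weights().values())
        gradients = tape.gradient(loss, variables)
        # Plain stochastic gradient descent update.
        for variable, gradient in zip(variables, gradients):
            variable.assign_sub(self.learning_rate * gradient)
        return {"loss": loss}

    @tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
    def infer(self, x):
        logits = self._logits(x)
        return {"output": tf.nn.softmax(logits, axis=-1), "logits": logits}

    @tf.function(input_signature=[tf.TensorSpec([], tf.string)])
    def save(self, checkpoint_path):
        weights = self._weights()
        # tf.raw_ops.Save writes a TensorFlow 1 format checkpoint.
        tf.raw_ops.Save(filename=checkpoint_path,
                        tensor_names=list(weights.keys()),
                        data=[w.read_value() for w in weights.values()],
                        name="save")
        return {"checkpoint_path": checkpoint_path}

    @tf.function(input_signature=[tf.TensorSpec([], tf.string)])
    def restore(self, checkpoint_path):
        restored = {}
        for name, variable in self._weights().items():
            value = tf.raw_ops.Restore(file_pattern=checkpoint_path, tensor_name=name,
                                       dt=variable.dtype, name="restore")
            variable.assign(value)
            restored[name] = value
        return restored
```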
The train function uses the GradientTape class to record operations for automatic differentiation. For more information on how to use this class, see the Introduction to gradients and automatic differentiation.
Instead of a from-scratch implementation, you could use the Model.train_step method of the Keras model here. Just note that the loss (and metrics) returned by Model.train_step are running averages and should be reset regularly (typically each epoch). See Customize Model.fit for details.
Note: The weights generated by this model are serialized into a TensorFlow 1 format checkpoint file.
Prepare the data
Get the Fashion MNIST dataset for training your model.
Preprocess the dataset
Pixel values in this dataset are between 0 and 255, and must be normalized to a value between 0 and 1 for processing by the model. Divide the values by 255 to make this adjustment.
Convert the data labels to categorical values by performing one-hot encoding.
Note: Make sure you preprocess your training and testing datasets in the same way, so that your testing accurately evaluates your model's performance.
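The two preprocessing steps can be sketched with NumPy alone. The helper name preprocess is hypothetical. It divides pixel values by 255 and builds one-hot labels by indexing rows of an identity matrix, because row i of that matrix is the one-hot vector for class i.

```python
import numpy as np

NUM_CLASSES = 10


def preprocess(images, labels, num_classes=NUM_CLASSES):
    """Scales pixels to [0, 1] and one-hot encodes integer labels."""
    # Pixel values are stored as 0..255 integers; divide by 255 to normalize.
    images = images.astype(np.float32) / 255.0
    # Row i of the identity matrix is the one-hot vector for class i.
    one_hot = np.eye(num_classes, dtype=np.float32)[labels]
    return images, one_hot
```

Apply the same function to both the training and the testing split returned by tf.keras.datasets.fashion_mnist.load_data(). That call yields uint8 image arrays of shape (N, 28, 28) and integer labels from 0 to 9.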
Train the model
Before converting and setting up your TensorFlow Lite model, complete the initial training of your model using the preprocessed dataset and the train signature method. Run the training for 100 epochs, processing batches of 100 images at a time and displaying the loss value after every 10 epochs. Because this training run processes quite a bit of data, it may take a few minutes to finish.
Note: You should complete initial training of your model before converting it to TensorFlow Lite format, so that the model has an initial set of weights, and is able to perform reasonable inferences before you start collecting data and conducting training runs on the device.
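The training loop described above can be sketched as follows. The sketch makes these assumptions:

- images and labels are NumPy arrays.
- train_fn is any callable that behaves like the train signature: it accepts a batch of images and one-hot labels and returns a dict with a "loss" entry. A tf.function-decorated train method of a tf.Module is one example.
- The helper name run_training is hypothetical.
- Reshuffling the data each epoch and reporting the mean batch loss per epoch are choices made here, not details taken from the tutorial.

```python
import numpy as np


def run_training(train_fn, images, labels, epochs=100, batch_size=100,
                 log_every=10, seed=0):
    """Runs mini-batch training and returns the mean loss of each epoch."""
    rng = np.random.default_rng(seed)
    epoch_losses = []
    for epoch in range(epochs):
        # Visit the examples in a new random order each epoch.
        order = rng.permutation(len(images))
        batch_losses = []
        for start in range(0, len(images), batch_size):
            idx = order[start:start + batch_size]
            result = train_fn(images[idx], labels[idx])
            batch_losses.append(float(result["loss"]))
        epoch_losses.append(float(np.mean(batch_losses)))
        if (epoch + 1) % log_every == 0:
            print(f"Finished {epoch + 1} epochs, loss: {epoch_losses[-1]:.3f}")
    return epoch_losses
```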
Convert model to TensorFlow Lite format
After you have extended your TensorFlow model to enable additional functions for on-device training and completed initial training of the model, you can convert it to TensorFlow Lite format. Convert and save your model in that format, including the set of signatures that you use with the TensorFlow Lite model on a device: train, infer, save, and restore.
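Below is a minimal sketch of that conversion. It assumes module is a tf.Module whose named methods are tf.function objects with fixed input signatures. The helper convert_with_signatures and its saved_model_dir argument are hypothetical. The sketch works in three steps:

1. It exports the module as a SavedModel with one named signature per function.
2. It enables SELECT_TF_OPS, which lets ops that have no TensorFlow Lite builtin fall back to TensorFlow kernels. The Save and Restore ops are examples of such ops.
3. It sets experimental_enable_resource_variables, which keeps the variables mutable so that train and restore can update them.

```python
import tensorflow as tf


def convert_with_signatures(module, signature_names, saved_model_dir):
    """Exports `module` with named signatures and converts it to TFLite bytes."""
    signatures = {
        name: getattr(module, name).get_concrete_function()
        for name in signature_names
    }
    tf.saved_model.save(module, saved_model_dir, signatures=signatures)

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # built-in TensorFlow Lite ops
        tf.lite.OpsSet.SELECT_TF_OPS,    # fall back to TensorFlow ops where needed
    ]
    # Keep variables mutable so training and restoring can update them.
    converter.experimental_enable_resource_variables = True
    return converter.convert()
```

For example, calling convert_with_signatures(m, ["train", "infer", "save", "restore"], "saved_model") returns the model as bytes, which you can then write to a .tflite file.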
Set up the TensorFlow Lite signatures
The TensorFlow Lite model you saved in the previous step contains several function signatures. You can access them through the tf.lite.Interpreter class and invoke each restore, train, save, and infer signature separately.
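The following is a minimal sketch of that setup. It assumes the converted model is available as bytes, for example the output of TFLiteConverter.convert. The helper load_signature_runners is hypothetical.

All runners are created from a single interpreter, so they share the same variables. If you created a separate interpreter per signature, each one would hold its own copy of the weights, and training through one would not affect inference through another.

```python
import tensorflow as tf


def load_signature_runners(model_content,
                           names=("train", "infer", "save", "restore")):
    """Returns {name: runner} for each requested signature present in the model."""
    interpreter = tf.lite.Interpreter(model_content=model_content)
    interpreter.allocate_tensors()
    available = interpreter.get_signature_list()
    # Runners from one interpreter share its variables (the model weights).
    return {name: interpreter.get_signature_runner(name)
            for name in names if name in available}
```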
Compare the output of the original model and the converted TensorFlow Lite model on the same input. The conversion to TFLite does not change the behavior of the model, so the two outputs should match up to small floating-point differences.
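One way to make that comparison concrete is to check two things: the largest absolute difference between the two models' logits, and whether both models predict the same label for every example. This is a hypothetical helper. The tolerance atol=1e-5 is an assumption and may need loosening for larger models.

```python
import numpy as np


def compare_predictions(reference_logits, lite_logits, atol=1e-5):
    """Returns (match, max_abs_diff) for two batches of logits."""
    reference = np.asarray(reference_logits, dtype=np.float64)
    lite = np.asarray(lite_logits, dtype=np.float64)
    max_abs_diff = float(np.max(np.abs(reference - lite)))
    # The predicted class for each example must agree as well.
    same_labels = bool(np.array_equal(reference.argmax(axis=-1), lite.argmax(axis=-1)))
    return (max_abs_diff <= atol and same_labels), max_abs_diff
```

Pass it the logits from the original model's infer function and from the TensorFlow Lite infer signature, computed on the same batch of images.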
Retrain the model on a device
After converting your model to TensorFlow Lite and deploying it with your app, you can retrain the model on a device using new data and the train signature method of your model. Each training run generates a new set of weights that you can save for re-use and further improvement of the model, as shown in the next section.
Note: Since training tasks are resource intensive, you should consider performing them when users are not actively interacting with the device, and as a background process. Consider using the WorkManager API to schedule model retraining as an asynchronous task.
On Android, you can perform on-device training with TensorFlow Lite using either Java or C++ APIs. In Java, use the Interpreter class to load a model and drive model training tasks, and call its runSignature method with the train signature key to run the training procedure.
You can see a complete code example of model retraining inside an Android app in the model personalization demo app.
Run training for a few epochs to improve or personalize the model. In practice, you would run this additional training using data collected on the device. For simplicity, this example uses the same training data as the previous training step.
Because the TensorFlow Lite model starts from the weights produced by the initial training, on-device training picks up exactly where the pretraining stopped.
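Below is a minimal sketch of running a few extra epochs through the TensorFlow Lite train signature. It assumes:

- train_runner comes from interpreter.get_signature_runner("train").
- The signature's inputs are named x and y. These names are fixed when the model is exported, so adjust them to match your own signature.
- The helper name continue_training is hypothetical.

The inputs are cast to float32 because the interpreter expects input dtypes that match the signature.

```python
import numpy as np


def continue_training(train_runner, images, labels, epochs=5, batch_size=100):
    """Runs extra epochs through a TFLite `train` signature runner."""
    # The signature was exported with float32 inputs, so match that dtype.
    images = np.asarray(images, dtype=np.float32)
    labels = np.asarray(labels, dtype=np.float32)
    epoch_losses = []
    for _ in range(epochs):
        batch_losses = [
            float(train_runner(x=images[i:i + batch_size],
                               y=labels[i:i + batch_size])["loss"])
            for i in range(0, len(images), batch_size)
        ]
        epoch_losses.append(float(np.mean(batch_losses)))
    return epoch_losses
```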
Save the trained weights
When you complete a training run on a device, the model updates the set of weights it is using in memory. Using the save signature method you created in your TensorFlow Lite model, you can save these weights to a checkpoint file for later reuse and further improvement of your model.
In your Android application, you can store the generated weights as a checkpoint file in the internal storage space allocated for your app.
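The Python side of this step can be sketched as follows. It assumes the save signature takes a scalar string input named checkpoint_path. The helper name save_checkpoint is hypothetical. The path is passed as a 0-d NumPy bytes array because the signature expects a scalar string tensor.

```python
import numpy as np


def save_checkpoint(save_runner, checkpoint_path):
    """Calls a TFLite `save` signature runner with a scalar string path."""
    # Encode the path and wrap it in a 0-d bytes array (scalar string tensor).
    path = np.array(checkpoint_path.encode("utf-8"))
    return save_runner(checkpoint_path=path)
```

On Android, the equivalent path would point into your app's internal storage directory instead of a desktop file system location.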
Restore the trained weights
Any time you create an interpreter from a TFLite model, the interpreter will initially load the original model weights. So after you've done some training and saved a checkpoint file, you'll need to run the restore signature method to load the checkpoint.
A good rule is "Anytime you create an Interpreter for a model, if the checkpoint exists, load it". If you need to reset the model to the baseline behavior, just delete the checkpoint and create a fresh interpreter.
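That rule can be sketched as two hypothetical helpers, create_interpreter and reset_to_baseline. The sketch assumes the model exposes a restore signature whose scalar string input is named checkpoint_path. Deleting the checkpoint and creating a fresh interpreter returns the model to its original weights.

```python
import os

import numpy as np
import tensorflow as tf


def create_interpreter(model_path, checkpoint_path):
    """Creates an interpreter and applies the saved checkpoint if one exists."""
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    if os.path.exists(checkpoint_path):
        restore = interpreter.get_signature_runner("restore")
        restore(checkpoint_path=np.array(checkpoint_path.encode("utf-8")))
    return interpreter


def reset_to_baseline(model_path, checkpoint_path):
    """Deletes the checkpoint so a fresh interpreter uses the original weights."""
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
    return create_interpreter(model_path, checkpoint_path)
```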
The checkpoint was generated by training and saving with TFLite, so applying it updates the behavior of the model to reflect that training.
Note: Loading the saved weights from the checkpoint can take time, depending on the number of variables in the model and the size of the checkpoint file.
In your Android app, you can restore the serialized, trained weights from the checkpoint file you stored earlier.
Note: When your application restarts, you should reload your trained weights prior to running new inferences.
Run inference using trained weights
Once you have loaded previously saved weights from a checkpoint file, running the infer method uses those weights with your original model to improve predictions. After loading the saved weights, call the infer signature method to run inference.
Note: Loading the saved weights is not required to run an inference, but running in that configuration produces predictions using the originally trained model, without improvements.
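The following is a minimal sketch of turning the infer output into predicted labels. It assumes:

- The infer signature takes an input named x.
- It returns class probabilities under the key "output".
- The helper name predict_labels is hypothetical.

CLASS_NAMES lists the ten Fashion MNIST categories in label order.

```python
import numpy as np

CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]


def predict_labels(infer_runner, images):
    """Runs the `infer` signature and returns class indices and names."""
    result = infer_runner(x=np.asarray(images, dtype=np.float32))
    # The most probable class is the predicted label for each image.
    indices = np.argmax(result["output"], axis=-1)
    return indices, [CLASS_NAMES[i] for i in indices]
```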
Plot the predicted labels.
In your Android application, after restoring the trained weights, run the inferences based on the loaded data.
Congratulations! You now have built a TensorFlow Lite model that supports on-device training. For more coding details, check out the example implementation in the model personalization demo app.
If you are interested in learning more about image classification, check out the Keras classification tutorial in the TensorFlow official guide. This tutorial is based on that exercise, and that tutorial provides more depth on the subject of classification.