Copyright 2020 The TensorFlow Authors.
Pruning comprehensive guide
Welcome to the comprehensive guide for Keras weight pruning.
This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the API docs.
If you want to see the benefits of pruning and what's supported, see the overview.
For a single end-to-end example, see the pruning example.
The following use cases are covered:
- Define and train a pruned model.
  - Sequential and Functional.
  - Keras model.fit and custom training loops.
- Checkpoint and deserialize a pruned model.
- Deploy a pruned model and see compression benefits.
For configuration of the pruning algorithm, refer to the tfmot.sparsity.keras.prune_low_magnitude API docs.
Setup
If you only want to find the APIs you need and understand what they are for, you can run this section without reading it.
Define model
Prune whole model (Sequential and Functional)
Tips for better model accuracy:
- Try "Prune some layers" to skip pruning the layers that reduce accuracy the most.
- It's generally better to fine-tune with pruning than to train from scratch.
To make the whole model train with pruning, apply tfmot.sparsity.keras.prune_low_magnitude to the model.
Prune some layers (Sequential and Functional)
Pruning a model can have a negative effect on accuracy. You can selectively prune layers of a model to explore the trade-off between accuracy, speed, and model size.
Tips for better model accuracy:
- It's generally better to fine-tune with pruning than to train from scratch.
- Try pruning the later layers instead of the first layers.
- Avoid pruning critical layers (e.g., attention mechanisms).
More: the tfmot.sparsity.keras.prune_low_magnitude API docs provide details on how to vary the pruning configuration per layer.
For example, you can prune only the Dense layers by passing a clone_function to tf.keras.models.clone_model.
While the approach above uses the layer's type to decide what to prune, the easiest way to prune a particular layer is to set its name property and look for that name in the clone_function.
More readable but potentially lower model accuracy
This approach is not compatible with fine-tuning with pruning, which is why it may be less accurate than the approaches above, which support fine-tuning.
Although prune_low_magnitude can be applied while defining the initial model, loading pretrained weights afterward does not work in the examples below.
Functional example
Sequential example
Prune custom Keras layer or modify parts of layer to prune
Common mistake: pruning the bias usually harms model accuracy too much.
tfmot.sparsity.keras.PrunableLayer serves two use cases:
- Prune a custom Keras layer.
- Modify which parts of a built-in Keras layer are pruned.
For example, the API defaults to pruning only the kernel of the Dense layer. Subclassing Dense with PrunableLayer lets you prune the bias as well.
Train model
Model.fit
Call the tfmot.sparsity.keras.UpdatePruningStep callback during training.
To help debug training, use the tfmot.sparsity.keras.PruningSummaries callback.
Custom training loop
Call the tfmot.sparsity.keras.UpdatePruningStep callback's methods manually during training.
To help debug training, use the tfmot.sparsity.keras.PruningSummaries callback, also invoking it manually.
Improve pruned model accuracy
First, look at the tfmot.sparsity.keras.prune_low_magnitude API docs to understand what a pruning schedule is and the math of each type of pruning schedule.
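To make the schedule math concrete, here is a plain-Python sketch of how we understand PolynomialDecay computes its target sparsity and when pruning runs at all. It is a re-implementation for illustration, not tfmot's code. The helper names are hypothetical, and the parameter names mirror the documented constructor arguments (initial_sparsity, final_sparsity, begin_step, end_step, power, frequency). Check the API docs for the authoritative definitions.

```python
def should_prune(step, begin_step, end_step, frequency):
  # Prune only within [begin_step, end_step] and only every `frequency` steps.
  # end_step < 0 means "never stop", which ConstantSparsity allows.
  in_range = step >= begin_step and (end_step < 0 or step <= end_step)
  return in_range and (step - begin_step) % frequency == 0


def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
  # Fraction of the pruning window that has elapsed, clipped to [0, 1].
  progress = (step - begin_step) / (end_step - begin_step)
  progress = min(1.0, max(0.0, progress))
  # Sparsity rises quickly at first, then flattens out near final_sparsity.
  return (final_sparsity
          + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power)


# ConstantSparsity, by contrast, targets the same sparsity at every pruning step.
for step in (0, 250, 500, 1000):
  target = polynomial_decay_sparsity(step, 0.0, 0.8, begin_step=0, end_step=1000)
  print(step, round(target, 4), should_prune(step, 0, 1000, frequency=100))
```

With power=3 the curve is front-loaded: halfway through the window the target is already 0.7 of a final 0.8, which leaves the later steps for the model to recover.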
Tips:
- Use a learning rate that's neither too high nor too low while the model is pruning. Consider the pruning schedule to be a hyperparameter.
- As a quick test, try pruning a model to the final sparsity at the beginning of training by setting begin_step to 0 with a tfmot.sparsity.keras.ConstantSparsity schedule. You might get lucky with good results.
- Do not prune very frequently, so that the model has time to recover. The pruning schedule provides a decent default frequency.
For general ideas to improve model accuracy, look for tips for your use case(s) under "Define model".
Checkpoint and deserialize
You must preserve the optimizer step during checkpointing. This means that while you can use Keras HDF5 models for checkpointing, you cannot use Keras HDF5 weights.
The requirement above applies generally. Loading the model inside tfmot.sparsity.keras.prune_scope() is only needed for the HDF5 model format (not for HDF5 weights or other formats).
Deploy pruned model
Export model with size compression
Common mistake: you need both strip_pruning and a standard compression algorithm (e.g., gzip) to see the compression benefits of pruning.
Hardware-specific optimizations
Once different backends enable pruning to improve latency, using block sparsity can improve latency for certain hardware.
Increasing the block size will decrease the peak sparsity that's achievable for a target model accuracy. Despite this, latency can still improve.
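The intuition behind this trade-off can be sketched in NumPy, independent of tfmot. At the same 50% sparsity, a mask that zeros whole 1x4 blocks must drop some large weights that unstructured magnitude pruning would keep, so it retains less of the total weight magnitude. Retained magnitude is only a rough proxy for accuracy. Ranking blocks by mean absolute value is our assumption about what the default 'AVG' block pooling does. The helper names are hypothetical.

```python
import numpy as np


def unstructured_mask(w, sparsity):
  # Zero the k smallest-magnitude individual weights.
  k = int(round(sparsity * w.size))
  mask = np.ones(w.size, dtype=bool)
  mask[np.argsort(np.abs(w), axis=None)[:k]] = False
  return mask.reshape(w.shape)


def block_mask(w, sparsity, block=(1, 4)):
  # Score each block by its mean absolute value, then zero whole blocks.
  rows, cols = w.shape
  br, bc = block
  scores = np.abs(w).reshape(rows // br, br, cols // bc, bc).mean(axis=(1, 3))
  k = int(round(sparsity * scores.size))
  keep = np.ones(scores.size, dtype=bool)
  keep[np.argsort(scores, axis=None)[:k]] = False
  keep = keep.reshape(scores.shape)
  # Expand the per-block decision back to per-weight resolution.
  return np.repeat(np.repeat(keep, br, axis=0), bc, axis=1)


rng = np.random.default_rng(0)
w = rng.standard_normal((16, 16))
for name, mask in [("unstructured", unstructured_mask(w, 0.5)),
                   ("1x4 blocks", block_mask(w, 0.5))]:
  kept = np.abs(w[mask]).sum() / np.abs(w).sum()
  print(f"{name}: sparsity={1 - mask.mean():.2f}, magnitude kept={kept:.3f}")
```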
For details on what's supported for block sparsity, see the tfmot.sparsity.keras.prune_low_magnitude API docs.