Path: blob/master/site/en-snapshot/lite/examples/jax_conversion/overview.ipynb
Copyright 2021 The TensorFlow Authors.
Jax Model Conversion For TFLite
Overview
Note: This API is new and only available via pip install tf-nightly. It will be available in TensorFlow version 2.7. Also, the API is still experimental and subject to changes.
This codelab demonstrates how to build a model for MNIST recognition using Jax, and how to convert it to TensorFlow Lite. It also shows how to optimize the Jax-converted TFLite model with post-training quantization.
Prerequisites
It's recommended to try this feature with the newest TensorFlow nightly pip build.
Data Preparation
Download the MNIST data with the Keras datasets API and pre-process it.
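A minimal sketch of this step, assuming the standard `tf.keras.datasets.mnist` loader; scaling pixel values to `[0, 1]` is the usual pre-processing convention (the notebook's exact transforms may differ):

```python
import numpy as np


def preprocess(images):
  # Scale uint8 pixel values [0, 255] to float32 [0.0, 1.0].
  return images.astype(np.float32) / 255.0


def load_mnist():
  # tf.keras.datasets.mnist.load_data() returns
  # ((train_images, train_labels), (test_images, test_labels)).
  import tensorflow as tf  # imported here so the module loads without TF
  (train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()
  return (preprocess(train_x), train_y), (preprocess(test_x), test_y)
```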
Build the MNIST model with Jax
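A hedged sketch of an MNIST classifier in pure Jax; a one-hidden-layer MLP over flattened 28x28 images. The notebook's actual architecture and initialization scheme may differ — the names `init_params` and `predict` are illustrative:

```python
import jax
import jax.numpy as jnp


def init_params(key, hidden=128):
  # Small random weights, zero biases; a common simple initialization.
  k1, k2 = jax.random.split(key)
  return {
      "w1": jax.random.normal(k1, (28 * 28, hidden)) * 0.01,
      "b1": jnp.zeros(hidden),
      "w2": jax.random.normal(k2, (hidden, 10)) * 0.01,
      "b2": jnp.zeros(10),
  }


def predict(params, images):
  # images: (batch, 28, 28) -> logits: (batch, 10)
  x = images.reshape(images.shape[0], -1)
  x = jnp.tanh(x @ params["w1"] + params["b1"])
  return x @ params["w2"] + params["b2"]
```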
Train & Evaluate the model
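The train-and-evaluate step can be sketched with cross-entropy loss, `jax.grad` for gradients, and plain SGD. The model's `predict` function (as built in the previous step) is passed in as the `predict_fn` argument so this cell stands on its own; hyperparameters are illustrative:

```python
import jax
import jax.numpy as jnp


def loss_fn(predict_fn, params, images, labels):
  # Mean cross-entropy over the batch.
  logits = predict_fn(params, images)
  log_probs = jax.nn.log_softmax(logits)
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))


def train_step(predict_fn, params, images, labels, lr=0.1):
  # One SGD update: differentiate the loss w.r.t. params (argument 1).
  grads = jax.grad(loss_fn, argnums=1)(predict_fn, params, images, labels)
  return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def accuracy(predict_fn, params, images, labels):
  return jnp.mean(jnp.argmax(predict_fn(params, images), axis=-1) == labels)
```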
Convert to a TFLite model.
Note here, we:

1. Inline the params into the Jax `predict` func with `functools.partial`.
2. Build a `jnp.zeros`, a "placeholder" tensor used by Jax to trace the model.
3. Call `experimental_from_jax`:
    - The `serving_func` is wrapped in a list.
    - The input is associated with a given name and passed in as an array wrapped in a list.
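The steps above can be sketched as follows. The `params`/`predict` pair here are tiny stand-ins so the cell runs on its own; in the notebook, they come from the training steps above. `experimental_from_jax` requires a tf-nightly build where the API is available:

```python
import functools

import jax.numpy as jnp
import tensorflow as tf

# Stand-ins for the trained model from the previous steps.
params = {"w": jnp.zeros((28 * 28, 10), jnp.float32)}


def predict(params, images):
  return images.reshape(images.shape[0], -1) @ params["w"]


# 1. Inline the trained params so the function only takes inputs.
serving_func = functools.partial(predict, params)

# 2. A zero-filled placeholder with the serving shape, used to trace the model.
placeholder = jnp.zeros((1, 28, 28), jnp.float32)

# 3. The serving function and its (name, placeholder) pairs are each
#    wrapped in a list: one entry per converted signature.
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[("input", placeholder)]])
tflite_model = converter.convert()
```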
Check the Converted TFLite Model
Compare the converted model's results with the Jax model.
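A sketch of running the converted flatbuffer through the TFLite `Interpreter`; the helper name `run_tflite` is illustrative:

```python
import tensorflow as tf


def run_tflite(tflite_model, sample):
  # Run one batch through the TFLite interpreter and return the output.
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()[0]
  output_details = interpreter.get_output_details()[0]
  interpreter.set_tensor(input_details["index"], sample)
  interpreter.invoke()
  return interpreter.get_tensor(output_details["index"])
```

The comparison itself would then be something like `np.allclose(run_tflite(tflite_model, x), predict(params, x), atol=1e-5)` on a batch of test images.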
Optimize the Model
We will provide a `representative_dataset` to do post-training quantization to optimize the model.
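A hedged sketch of this step, written as helpers so the cell stands alone; `images` would be the pre-processed training images, and `converter` the `TFLiteConverter` built in the conversion step (e.g. via `experimental_from_jax`):

```python
import numpy as np
import tensorflow as tf


def representative_dataset(images, num_samples=100):
  # The converter runs these samples through the model to calibrate
  # quantization ranges for weights and activations.
  for i in range(min(num_samples, len(images))):
    yield [images[i:i + 1].astype(np.float32)]


def quantize(converter, images):
  # Enable the default post-training optimizations and attach the
  # calibration dataset, then re-convert.
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  converter.representative_dataset = lambda: representative_dataset(images)
  return converter.convert()
```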
Evaluate the Optimized Model
Compare the Quantized Model size
We should be able to see that the quantized model is about four times smaller than the original model, since float32 weights are stored as int8.
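Since a serialized TFLite model is just bytes, the comparison is a length ratio; a minimal helper (name illustrative), applied to the float and quantized flatbuffers:

```python
def size_ratio(float_model, quant_model):
  # Both arguments are serialized TFLite flatbuffers (bytes);
  # comparing their lengths compares on-disk sizes.
  return len(float_model) / len(quant_model)
```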