Copyright 2023 The TensorFlow Authors.
Import a JAX model using JAX2TF
This notebook provides a complete, runnable example of creating a model using JAX and bringing it into TensorFlow to continue training. This is made possible by JAX2TF, a lightweight API that provides a pathway from the JAX ecosystem to the TensorFlow ecosystem.
JAX is a high-performance array computing library. To create the model, this notebook uses Flax, a neural network library for JAX. To train it, it uses Optax, an optimization library for JAX.
If you're a researcher using JAX, JAX2TF gives you a path to production using TensorFlow's proven tools.
There are many ways this can be useful; here are just a few:
Inference: Taking a model written for JAX and deploying it either on a server using TF Serving, on-device using TFLite, or on the web using TensorFlow.js.
Fine-tuning: Taking a model that was trained using JAX, bringing its components into TensorFlow using JAX2TF, and continuing to train it in TensorFlow with your existing training data and setup.
Fusion: Combining parts of models that were trained using JAX with those trained using TensorFlow, for maximum flexibility.
The key to enabling this kind of interoperation between JAX and TensorFlow is jax2tf.convert, which takes in model components created on top of JAX (your loss function, prediction function, etc.) and creates equivalent representations of them as TensorFlow functions, which can then be exported as a TensorFlow SavedModel.
Setup
Download and prepare the MNIST dataset
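The exact loading code isn't reproduced here, but MNIST preparation typically scales pixel values to [0, 1] and flattens each 28x28 image into a 784-element vector. The prepare function below is an illustrative sketch of that step, run on synthetic data rather than the real dataset:

```python
import numpy as np

def prepare(images, labels):
    # Scale uint8 pixel values to floats in [0, 1] and flatten each image.
    images = images.astype(np.float32) / 255.0
    return images.reshape(len(images), -1), labels

# Synthetic stand-in for a batch of MNIST images and labels.
x = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)
y = np.array([3, 1, 4, 1])
xp, yp = prepare(x, y)  # xp has shape (4, 784), values in [0, 1]
```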
Configure training
This notebook will create and train a simple model for demonstration purposes.
Create the model using Flax
Write the training step function
Write the training loop
Create the model and the optimizer (with Optax)
Train the model
Partially train the model
You will continue training the model in TensorFlow shortly.
Save just enough for inference
If your goal is to deploy your JAX model (so you can run inference using model.predict()), simply exporting it to SavedModel is sufficient. This section demonstrates how to accomplish that.
Save everything
If your goal is a comprehensive export (useful if you're planning on bringing the model into TensorFlow for fine-tuning, fusion, etc.), this section demonstrates how to save the model so you can access methods including:
model.predict
model.accuracy
model.loss (including train=True/False bool, RNG for dropout and BatchNorm state updates)
Reload the model
Continue training the converted JAX model in TensorFlow
Next steps
You can learn more about JAX and Flax on their documentation websites, which contain detailed guides and examples. If you're new to JAX, be sure to explore the JAX 101 tutorials, and check out the Flax quickstart. To learn more about converting JAX models to TensorFlow format, check out the jax2tf utility on GitHub. If you're interested in converting JAX models to run in the browser with TensorFlow.js, visit JAX on the Web with TensorFlow.js. If you'd like to prepare JAX models to run in TensorFlow Lite, visit the JAX Model Conversion For TFLite guide.