Path: blob/master/site/zh-cn/lite/examples/jax_conversion/overview.ipynb
25118 views
Kernel: Python 3
Copyright 2021 The TensorFlow Authors.
In [ ]:
用于 TFLite 的 Jax 模型转换
概述
注:此为新 API ,只有通过 pip 安装 tf-nighly 才能使用。它将在 TensorFlow 2.7 版中提供。另外,此 API 仍处于实验阶段,可能会发生变化。
此 CodeLab 演示了如何使用 Jax 构建 MNIST 识别模型,以及如何将其转换为 TensorFlow Lite。此 CodeLab 还将演示如何使用训练后量化来优化 Jax 转换的 TFLite 模型。
先决条件
建议在最新的 TensorFlow nightly pip 构建中尝试此功能。
In [ ]:
数据准备
使用 Keras 数据集下载 MNIST 数据并进行预处理。
In [ ]:
In [ ]:
使用 Jax 构建 MNIST 模型
In [ ]:
训练并评估模型
In [ ]:
转换为 TFLite 模型
请注意,我们需要执行以下操作:
使用
functools.partial
将参数内联到 Jaxpredict
函数。构建一个
jnp.zeros
,这是一个用于 Jax 跟踪模型的“占位符”张量。调用
experimental_from_jax
:
serving_func
被封装在一个列表中。输入与给定的名称相关联,并作为封装在列表中的数组传入。
In [ ]:
检查转换后的 TFLite 模型
将转换后的模型的结果与 Jax 模型进行比较。
In [ ]:
优化模型
我们将提供一个 representative_dataset
来进行训练后量化,以优化模型。
In [ ]:
评估优化后的模型
In [ ]:
比较量化模型大小
我们应该能够看到,量化模型的大小缩减为了原始模型的四分之一。
In [ ]: