Path: blob/master/site/ja/lite/examples/jax_conversion/overview.ipynb
25118 views
Kernel: Python 3
Copyright 2021 The TensorFlow Authors.
In [ ]:
TFLite向けのJaxモデル変換
概要
注意: これは新しい API であり、pip install tf-nightly 経由でのみ使用できます。また、TensorFlow バージョン 2.7 で提供される予定です。この API は実験段階であり、変更される可能性があります。
この CodeLab では、Jax を使用して MNIST 認識のモデルを構築する方法と、それを TensorFlow Lite に変換する方法について説明します。また、トレーニング後の量子化によって、Jaxに変換されたモデルを最適化する方法についても説明します。
前提条件
最新の TensorFlow nightly pip ビルドでこの機能を試すことをお勧めします。
In [ ]:
データの準備
MNIST データ、Keras データセット、プリプロセスをダウンロードします。
In [ ]:
In [ ]:
Jax で MNIST モデルを構築する
In [ ]:
モデルのトレーニングと評価
In [ ]:
TFLite モデルに変換する
次の手順を実行します。
Jax
predict
関数へのパラメーターをfunctools.partial
でインライン化します。jnp.zeros
を作成します。これは Jax でモデルを追跡するために使用される「プレースホルダー」テンソルです。experimental_from_jax
を呼び出します。
serving_func
がリストでラップされます。入力は特定の名前に関連付けられ、リストでラップされた配列として渡されます。
In [ ]:
変換された TFLite モデルを確認する
変換されたモデルの結果を Jax モデルと比較します。
In [ ]:
モデルを最適化する
モデルを最適化するために、representative_dataset
を提供してトレーニング後の量子化を実行します。
In [ ]:
最適化されたモデルを評価する
In [ ]:
量子化されたモデルサイズを比較する
量子化されたモデルのサイズは元のモデルの 4 分の 1 のサイズになることがわかります。
In [ ]: