Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/ja/lite/examples/jax_conversion/overview.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

TFLite向けのJaxモデル変換

概要

注意: これは新しい API であり、pip install tf-nightly 経由でのみ使用できます。また、TensorFlow バージョン 2.7 で提供される予定です。この API は実験段階であり、変更される可能性があります。

この CodeLab では、Jax を使用して MNIST 認識のモデルを構築する方法と、それを TensorFlow Lite に変換する方法について説明します。また、トレーニング後の量子化によって、Jaxに変換されたモデルを最適化する方法についても説明します。

前提条件

最新の TensorFlow nightly pip ビルドでこの機能を試すことをお勧めします。

!pip install tf-nightly --upgrade !pip install jax --upgrade !pip install jaxlib --upgrade

データの準備

MNIST データ、Keras データセット、プリプロセスをダウンロードします。

import numpy as np import tensorflow as tf import functools import time import itertools import numpy.random as npr import jax.numpy as jnp from jax import jit, grad, random from jax.example_libraries import optimizers from jax.example_libraries import stax
def _one_hot(x, k, dtype=np.float32): """Create a one-hot encoding of x of size k.""" return np.array(x[:, None] == np.arange(k), dtype) (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() train_images, test_images = train_images / 255.0, test_images / 255.0 train_images = train_images.astype(np.float32) test_images = test_images.astype(np.float32) train_labels = _one_hot(train_labels, 10) test_labels = _one_hot(test_labels, 10)

Jax で MNIST モデルを構築する

def loss(params, batch): inputs, targets = batch preds = predict(params, inputs) return -jnp.mean(jnp.sum(preds * targets, axis=1)) def accuracy(params, batch): inputs, targets = batch target_class = jnp.argmax(targets, axis=1) predicted_class = jnp.argmax(predict(params, inputs), axis=1) return jnp.mean(predicted_class == target_class) init_random_params, predict = stax.serial( stax.Flatten, stax.Dense(1024), stax.Relu, stax.Dense(1024), stax.Relu, stax.Dense(10), stax.LogSoftmax) rng = random.PRNGKey(0)

モデルのトレーニングと評価

step_size = 0.001 num_epochs = 10 batch_size = 128 momentum_mass = 0.9 num_train = train_images.shape[0] num_complete_batches, leftover = divmod(num_train, batch_size) num_batches = num_complete_batches + bool(leftover) def data_stream(): rng = npr.RandomState(0) while True: perm = rng.permutation(num_train) for i in range(num_batches): batch_idx = perm[i * batch_size:(i + 1) * batch_size] yield train_images[batch_idx], train_labels[batch_idx] batches = data_stream() opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass) @jit def update(i, opt_state, batch): params = get_params(opt_state) return opt_update(i, grad(loss)(params, batch), opt_state) _, init_params = init_random_params(rng, (-1, 28 * 28)) opt_state = opt_init(init_params) itercount = itertools.count() print("\nStarting training...") for epoch in range(num_epochs): start_time = time.time() for _ in range(num_batches): opt_state = update(next(itercount), opt_state, next(batches)) epoch_time = time.time() - start_time params = get_params(opt_state) train_acc = accuracy(params, (train_images, train_labels)) test_acc = accuracy(params, (test_images, test_labels)) print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time)) print("Training set accuracy {}".format(train_acc)) print("Test set accuracy {}".format(test_acc))

TFLite モデルに変換する

次の手順を実行します。

  1. Jax predict 関数へのパラメーターを functools.partial でインライン化します。

  2. jnp.zeros を作成します。これは Jax でモデルを追跡するために使用される「プレースホルダー」テンソルです。

  3. experimental_from_jax を呼び出します。

  • serving_func がリストでラップされます。

  • 入力は特定の名前に関連付けられ、リストでラップされた配列として渡されます。

serving_func = functools.partial(predict, params) x_input = jnp.zeros((1, 28, 28)) converter = tf.lite.TFLiteConverter.experimental_from_jax( [serving_func], [[('input1', x_input)]]) tflite_model = converter.convert() with open('jax_mnist.tflite', 'wb') as f: f.write(tflite_model)

変換された TFLite モデルを確認する

変換されたモデルの結果を Jax モデルと比較します。

expected = serving_func(train_images[0:1]) # Run the model with TensorFlow Lite interpreter = tf.lite.Interpreter(model_content=tflite_model) interpreter.allocate_tensors() input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :]) interpreter.invoke() result = interpreter.get_tensor(output_details[0]["index"]) # Assert if the result of TFLite model is consistent with the JAX model. np.testing.assert_almost_equal(expected, result, 1e-5)

モデルを最適化する

モデルを最適化するために、representative_dataset を提供してトレーニング後の量子化を実行します。

def representative_dataset(): for i in range(1000): x = train_images[i:i+1] yield [x] converter = tf.lite.TFLiteConverter.experimental_from_jax( [serving_func], [[('x', x_input)]]) tflite_model = converter.convert() converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.representative_dataset = representative_dataset converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] tflite_quant_model = converter.convert() with open('jax_mnist_quant.tflite', 'wb') as f: f.write(tflite_quant_model)

最適化されたモデルを評価する

expected = serving_func(train_images[0:1]) # Run the model with TensorFlow Lite interpreter = tf.lite.Interpreter(model_content=tflite_quant_model) interpreter.allocate_tensors() input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :]) interpreter.invoke() result = interpreter.get_tensor(output_details[0]["index"]) # Assert if the result of TFLite model is consistent with the Jax model. np.testing.assert_almost_equal(expected, result, 1e-5)

量子化されたモデルサイズを比較する

量子化されたモデルのサイズは元のモデルの 4 分の 1 のサイズになることがわかります。

!du -h jax_mnist.tflite !du -h jax_mnist_quant.tflite