Path: blob/master/site/ja/lite/examples/on_device_training/overview.ipynb
25118 views
Copyright 2021 The TensorFlow Authors.
TensorFlow Lite でのオンデバイストレーニング
TensorFlow Lite*{nbsp}*機械学習モデルをデバイスまたはモバイルアプリにデプロイするときには、デバイスまたはエンドユーザーの入力に基づいて、モデルを改良またはパーソナル化したい場合があります。オンデバイストレーニング手法を使用すると、データをユーザーのデバイスから移動させずにモデルを更新できます。これにより、ユーザープライバシーが強化され、ユーザーはデバイスソフトウェアを更新する必要がありません。
たとえば、モバイルアプリでファッションアイテムを認識するモデルがあり、ユーザーの関心に基づいて経時的に認識パフォーマンスを改善させたいとします。オンデバイストレーニングを有効にすると、靴に関心があるユーザーは、アプリを使用するほど、特定の靴のスタイルや靴ブランドの認識能力が高くなります。
このチュートリアルでは、インストールされた Android アプリ内で、増分的にトレーニング、改善できる TensorFlow Lite モデルを構築する方法について説明します。
注意: 対象のデバイスでローカルファイルストレージがサポートされている場合は、オンデバイストレーニング手法を既存の TensorFlow Lite 実装に追加できます。
設定
このチュートリアルでは、Python を使用して、TensorFlow モデルをトレーニング、変換します。その後に、Android アプリに統合します。まず、次のパッケージをインストールしてインポートします。
注意: On-Device Training API は TensorFlow バージョン 2.7 以上で提供されています。
服飾の画像を分類する
このサンプルコードでは、Fashion MNIST データセットを使用して、服飾の画像を分類するニュートラルネットワークモデルをトレーニングします。このデータセットには、 6 万個の小さい (28 x 28 ピクセル) グレースケール画像が含まれています。画像には、ドレス、シャツ、サンダルなどの 10 種類のファッションアクセサリのカテゴリがあります。
<figure> <img src="https://tensorflow.org/images/fashion-mnist-sprite.png" alt="Fashion MNIST images"> <figcaption><b>Figure 1</b>: <a href="https://github.com/zalandoresearch/fashion-mnist">Fashion-MNIST samples</a> (by Zalando, MIT License).</figcaption> </figure>
<figure> <img src="https://tensorflow.org/images/fashion-mnist-sprite.png" alt="Fashion MNIST images"> <figcaption><b>Figure 1</b>: <a href="https://github.com/zalandoresearch/fashion-mnist">Fashion-MNIST samples</a> (by Zalando, MIT License).</figcaption> </figure>
オンデバイストレーニングのモデルを作成する
一般的に、TensorFlow Lite モデルには、公開された関数メソッド (シグネチャ) が 1 つだけあり、それによってモデルを呼び出して推論を実行できます。デバイスでモデルをトレーニングして使用するには、モデルのトレーニング、推論、保存、復元関数といった、複数の個別の演算を実行できる必要があります。この機能を有効にするには、まず、複数の関数を使用できるように TensorFlow モデルを拡張します。次に、モデルを TensorFlow Lite モデル形式に変換するときに、これらの関数をシグネチャとして公開します。
次のコードサンプルは、次の関数を TensorFlow モデルに追加する方法について説明します。
train
関数: トレーニングデータを使用してモデルをトレーニングします。infer
関数: 推論を実行します。save
関数: トレーニング可能な重みをファイルシステムに保存します。restore
関数: トレーニング可能な重みをファイルシステムから読み込みます。
上記のコードの train
関数は GradientTape クラスを使用して、自動微分の演算を記録します。このクラスの使用方法の詳細については、勾配と自動微分の概要を参照してください。
ここでは、ゼロから実装するのではなく、keras モデルの Model.train_step
メソッドを使用できます。Model.train_step
によって返される損失 (およびメトリクス) は移動平均であり、定期的に (通常はエポックごとに) リセットしてください。詳細については、Model.fit のカスタマイズを参照してください。
注意: このモデルで生成される重みは、TensorFlow 1 形式のチェックポイントファイルにシリアル化されます。
データを準備する
モデルをトレーニングするための Fashion MNIST データセットを取得します。
データの前処理
このデータセットのピクセル値は、0 ~ 255 です。この値をモデルで処理するためには、0 ~ 1 の範囲の値に正規化する必要があります。値を 255 で除算すると、正規化できます。
ワンホットエンコーディングを実行して、データラベルをカテゴリ値に変換します。
注意:**{nbsp}**必ずトレーニングデータセットとテストデータセットは同じ方法で前処理し、テストでモデルのパフォーマンスを正確に評価できるようにしてください。
モデルのトレーニング
TensorFlow Lite モデルを変換、設定する前に、前処理済みのデータセットと train
シグネチャを使用して、モデルの初期トレーニングを完了します。次のコードは 100 エポックでモデルトレーニングを実行して、100 個の画像のバッチを一度に処理し、10 エポックごとに損失値を表示します。このトレーニング実行ではかなりのデータが処理されるため、完了するのに数分かかる場合があります。
注意: TensorFlow Lite 形式に変換する前に、モデルの初期トレーニングを完了してください。これにより、モデルに重みの初期セットが追加され、データの収集と、デバイスでのトレーニングの実行を開始する前に、合理的な推論を実行できます。
モデルを TensorFlow Lite 形式に変換する
TensorFlow モデルを拡張して、オンデバイストレーニングの追加の関数を有効にし、モデルの初期トレーニングを完了した後は、そのモデルを TensorFlow Lite モデルに変換できます。次のコードは、デバイスで Tensorflow Lite モデルを変換し、モデルとともに使用するシグネチャのセットを含む形式にモデルを保存します。train, infer, save, restore
TensorFlow Lite シグネチャを設定する
前のステップで保存した TensorFlow Lite モデルには、複数の関数シグネチャが含まれます。tf.lite.Interpreter
クラス経由でシグネチャにアクセスし、それぞれ個別に restore
、train
、save
、infer
シグネチャを呼び出すことができます。
元のモデルの出力と、変換された Lite モデルを比較します。
上記では、モデルの動作が TFLite への変換によって変わっていないことを確認できます。
デバイスでモデルを再トレーニングする
モデルを Tensorflow Lite に変換し、アプリでデプロイした後は、新しいデータとモデルの train
シグネチャ
メソッドを使用して、デバイスでモデルを再トレーニングできます。各トレーニング実行では、新しい重みのセットが生成されます。次のセクションで示すように、重みを保存すると、再利用したり、モデルのさらなる改善で使用できます。
注意: トレーニングタスクはリソースの消費量が大きいため、ユーザーがデバイスを操作していないときに実行するか、バックグラウンドプロセスとして実行することを検討してください。また、WorkManager API を使用して、非同期タスクとしてモデルの再トレーニングをスケジュールすることを検討してください。
Android では、Java API または C + + API を使用して、TensorFlow Lite でオンデバイストレーニングを実行できます。Java では、Interpreter
クラスを使用して、モデルを読み込み、モデルトレーニングタスクを実行します。次の例では、runSignature
メソッドを使用したトレーニング手順を実行する方法について示します。
モデルパーソならぜーションデモアプリでは、Android アプリ内に保持されているモデルのコードサンプル全体を確認できます。
2、3 エポック分のトレーニングを実行し、モデルを改善またはパーソナライズします。実際には、デバイスで収集されたデータを使用して、この追加トレーニングを実行してください。簡潔にするために、この例では、前のトレーニングステップと同じトレーニングデータを使用しています。
上記では、事前トレーニングが停止した正確な位置からオンデバイストレーニングが開始することを確認できます。
トレーニングされた重みの保存
デバイスでトレーニングの実行を完了すると、メモリで使用されていた重みのセットがモデルで更新されます。TensorFlow Lite モデルで作成した save
シグネチャメソッドを使用すると、これらの重みをチェックポイントファイルに保存して、後から再利用したり、モデルを改善したりできます。
Android アプリケーションでは、アプリに割り当てられた内部ストレージ領域にあるチェックポイントファイルとして、生成された重みを保存できます。
トレーニングされた重みの復元
TFLite モデルからインタープリタを作成するたびに、インタープリタでは最初に元のモデルの重みが読み込まれます。
トレーニングを実行し、チェックポイントファイルを保存した後は、restore
シグネチャメソッドを使用して、チェックポイントを読み込む必要があります。
「チェックポイントが存在する場合は、モデルのインタープリタを作成するたびに読み込む」というルールを設定しておくことをお勧めします。モデルをベースラインの動作にリセットする必要がある場合は、チェックポイントを削除し、新しい
インタープリタを作成します。
チェックポイントは、TFLite によるトレーニングと保存によって生成されます。上記では、チェックポイントを適用すると、モデルの動作が更新されることが確認できます。
注意: モデルの変数の数とチェックポイントのサイズによっては、チェックポイントから保存された重みを読み込むときに時間がかかります。
Android アプリでは、シリアル化されたトレーニング済みの重みを、前に保存したチェックポイントファイルから復元できます。
注意: アプリケーションが再起動するときには、新しい推論を実行する前に、トレーニング済みの重みを再読み込みしてください。
トレーニング済みの重みを使用した推論の実行
以前に保存した重みをチェックポイントから読み込んだ後、infer
メソッドを実行すると、これらの重みと元のモデルが使用され、予測を改善します。保存された重みを読み込んだ後は、次のように、infer
シグネチャメソッドを使用できます。
注意: 推論を実行するために保存された重みを読み込む必要はありません。ただし、その構成で実行すると、最初にトレーニングされたモデルを使用して、改善を行わずに、予測が生成されます。
予測されたラベルをプロットします。
Android アプリケーションで、トレーニング済みの重みを復元した後に、読み込まれたデータに基づいて推論を実行します。
これで、オンデバイストレーニングをサポートする TensorFlow Lite モデルを構築できました。詳細については、モデルパーソナライゼーションデモアプリの実装例を確認してください。
画像分類の詳細については、TensorFlow 公式ガイドページの Keras 分類チュートリアルを確認してください。このチュートリアルは、その演習に基づいていて、分類のテーマを掘り下げて行きます。