Copyright 2018 The TensorFlow Authors.
トレーニングのチェックポイント
「TensorFlow のモデルを保存する」という言いまわしは通常、次の 2 つのいずれかを意味します。
チェックポイント、または
保存されたモデル(SavedModel)
チェックポイントは、モデルで使用されるすべてのパラメータ(tf.Variable
オブジェクト)の正確な値をキャプチャします。チェックポイントにはモデルで定義された計算のいかなる記述も含まれていないため、通常は、保存されたパラメータ値を使用するソースコードが利用可能な場合に限り有用です。
一方、SavedModel 形式には、パラメータ値(チェックポイント)に加え、モデルで定義された計算のシリアライズされた記述が含まれています。この形式のモデルは、モデルを作成したソースコードから独立しています。したがって、TensorFlow Serving、TensorFlow Lite、TensorFlow.js、または他のプログラミング言語のプログラム(C、C++、Java、Go、Rust、C# などの TensorFlow API)を介したデプロイに適しています。
このガイドでは、チェックポイントの書き込みと読み取りを行う API について説明します。
セットアップ
tf.keras
トレーニング API から保存する
tf.keras
の保存と復元に関するガイドをご覧ください。
tf.keras.Model.save_weights
で TensorFlow チェックポイントを保存します。
チェックポイントを記述する
TensorFlow モデルの永続的な状態は、tf.Variable
オブジェクトに格納されます。これらは直接作成できますが、多くの場合はtf.keras.layers
やtf.keras.Model
などの高レベル API を介して作成されます。
変数を管理する最も簡単な方法は、変数を Python オブジェクトにアタッチし、それらのオブジェクトを参照することです。
tf.train.Checkpoint
、tf.keras.layers.Layer
およびtf.keras.Model
のサブクラスは、属性に割り当てられた変数を自動的に追跡します。以下の例では、単純な線形モデルを作成し、モデルのすべての変数の値を含むチェックポイントを記述します。
Model.save_weights
で、モデルチェックポイントを簡単に保存できます。
手動チェックポイント
セットアップ
tf.train.Checkpoint
のすべての機能を実演するために、トイデータセットと最適化ステップを次のように定義します。
チェックポイントオブジェクトを作成する
チェックポイントを手動で作成するには、tf.train.Checkpoint
オブジェクトを使用します。チェックポイントを設定するオブジェクトは、オブジェクトの属性として設定されます。
tf.train.CheckpointManager
は、複数のチェックポイントの管理にも役立ちます。
モデルをトレーニングおよびチェックポイントする
次のトレーニングループは、モデルとオプティマイザのインスタンスを作成し、それらをtf.train.Checkpoint
オブジェクトに集めます。それはデータの各バッチのループ内でトレーニングステップを呼び出し、定期的にチェックポイントをディスクに書き込みます。
復元してトレーニングを続ける
最初のトレーニングサイクルの後、新しいモデルとマネージャーを渡すことができますが、トレーニングはやめた所から再開します。
tf.train.CheckpointManager
オブジェクトは古いチェックポイントを削除します。上記では、最新の 3 つのチェックポイントのみを保持するように構成されています。
これらのパス、例えば'./tf_ckpts/ckpt-10'
などは、ディスク上のファイルではなく、index
ファイルのプレフィックスで、変数値を含む 1 つまたはそれ以上のデータファイルです。これらのプレフィックスは、まとめて単一のcheckpoint
ファイル('./tf_ckpts/checkpoint'
)にグループ化され、CheckpointManager
がその状態を保存します。
読み込みの仕組み
TensorFlowは、読み込まれたオブジェクトから始めて、名前付きエッジを持つ有向グラフを走査することにより、変数をチェックポイントされた値に合わせます。エッジ名は通常、オブジェクトの属性名に由来しており、self.l1 = tf.keras.layers.Dense(5)
の"l1"
などがその例です。tf.train.Checkpoint
は、tf.train.Checkpoint(step=...)
の"step"
のように、キーワード引数名を使用します。
上記の例の依存関係グラフは次のようになります。
オプティマイザは赤、通常の変数は青、オプティマイザスロット変数はオレンジで表されています。tf.train.Checkpoint
を表すノードなどは黒で示されています。
オプティマイザは赤色、通常変数は青色、オプティマイザスロット変数はオレンジ色です。他のノード、例えばtf.train.Checkpoint
を表すものは黒色です。
tf.train.Checkpoint
オブジェクトで restore
を読み出すと、リクエストされた復元がキューに入れられ、Checkpoint
オブジェクトから一致するパスが見つかるとすぐに変数値が復元されます。たとえば、ネットワークとレイヤーを介してバイアスのパスを再構築すると、上記で定義したモデルからそのバイアスのみを読み込むことができます。
これらの新しいオブジェクトの依存関係グラフは、上で書いたより大きなチェックポイントのはるかに小さなサブグラフです。 これには、バイアスと tf.train.Checkpoint
がチェックポイントに番号付けするために使用する保存カウンタのみが含まれます。
restore
は、オプションのアサーションを持つステータスオブジェクトを返します。新しい Checkpoint
で作成されたすべてのオブジェクトが復元されるため、status.assert_existing_objects_matched
がパスとなります。
チェックポイントには、レイヤーのカーネルやオプティマイザの変数など、一致しない多くのオブジェクトがあります。status.assert_consumed()
は、チェックポイントとプログラムが正確に一致する場合に限りパスするため、ここでは例外がスローされます。
復元延期 (Deferred restoration)
TensorFlow のLayer
オブジェクトは、入力形状が利用可能な場合、最初の呼び出しまで変数の作成を遅らせる可能性があります。例えば、Dense
レイヤーのカーネルの形状はレイヤーの入力形状と出力形状の両方に依存するため、コンストラクタ引数として必要な出力形状は、単独で変数を作成するために充分な情報ではありません。Layer
の呼び出しは変数の値も読み取るため、復元は変数の作成とその最初の使用の間で発生する必要があります。
このイディオムをサポートするために、tf.train.Checkpoint
は一致する変数がまだない場合、復元を延期します。
チェックポイントを手動で検査する
tf.train.load_checkpoint
は、チェックポイントのコンテンツにより低いレベルのアクセスを提供する CheckpointReader
を返します。これには各変数のキーからチェックポイントの各変数の形状と dtype へのマッピングが含まれます。変数のキーは上に表示されるグラフのようなオブジェクトパスです。
注意: チェックポイントへのより高いレベルの構造はありません。変数のパスと値のみが認識されており、models
、layers
、またはそれらがどのように接続されているかについての概念が一切ありません。
net.l1.kernel
の値に関心がある場合は、次のコードを使って値を取得できます。
また、変数の値を検査できるようにする get_tensor
メソッドも提供されています。
オブジェクトの追跡
self.l1 = tf.keras.layers.Dense(5)
のような直接の属性割り当てと同様に、リストとディクショナリを属性に割り当てると、それらの内容を追跡します。
self.l1 = tf.keras.layers.Dense(5)
のような直接の属性割り当てと同様に、リストとディクショナリを属性に割り当てると、それらの内容を追跡します。
リストとディクショナリのラッパーオブジェクトにお気づきでしょうか。これらのラッパーは基礎的なデータ構造のチェックポイント可能なバージョンです。属性に基づく読み込みと同様に、これらのラッパーは変数の値がコンテナに追加されるとすぐにそれを復元します。
追跡可能なオブジェクトには tf.train.Checkpoint
、tf.Module
およびそのサブクラス ( keras.layers.Layer
や keras.Model
など)、および認識された Python コンテナが含まれています。
dict
(およびcollections.OrderedDict
)list
tuple
(およびcollections.namedtuple
、typing.NamedTuple
)
以下のような他のコンテナタイプはサポートされていません。
collections.defaultdict
set
以下のような他のすべての Python オブジェクトは無視されます。
int
string
float
まとめ
TensorFlow オブジェクトは、それらが使用する変数の値を保存および復元するための容易で自動的な仕組みを提供します。