Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/ja/guide/checkpoint.ipynb
25115 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

トレーニングのチェックポイント

「TensorFlow のモデルを保存する」という言いまわしは通常、次の 2 つのいずれかを意味します。

  1. チェックポイント、または

  2. 保存されたモデル(SavedModel)

チェックポイントは、モデルで使用されるすべてのパラメータ(tf.Variableオブジェクト)の正確な値をキャプチャします。チェックポイントにはモデルで定義された計算のいかなる記述も含まれていないため、通常は、保存されたパラメータ値を使用するソースコードが利用可能な場合に限り有用です。

一方、SavedModel 形式には、パラメータ値(チェックポイント)に加え、モデルで定義された計算のシリアライズされた記述が含まれています。この形式のモデルは、モデルを作成したソースコードから独立しています。したがって、TensorFlow Serving、TensorFlow Lite、TensorFlow.js、または他のプログラミング言語のプログラム(C、C++、Java、Go、Rust、C# などの TensorFlow API)を介したデプロイに適しています。

このガイドでは、チェックポイントの書き込みと読み取りを行う API について説明します。

セットアップ

import tensorflow as tf
class Net(tf.keras.Model): """A simple linear model.""" def __init__(self): super(Net, self).__init__() self.l1 = tf.keras.layers.Dense(5) def call(self, x): return self.l1(x)
net = Net()

tf.kerasトレーニング API から保存する

tf.kerasの保存と復元に関するガイドをご覧ください。

tf.keras.Model.save_weightsで TensorFlow チェックポイントを保存します。

net.save_weights('easy_checkpoint')

チェックポイントを記述する

TensorFlow モデルの永続的な状態は、tf.Variableオブジェクトに格納されます。これらは直接作成できますが、多くの場合はtf.keras.layerstf.keras.Modelなどの高レベル API を介して作成されます。

変数を管理する最も簡単な方法は、変数を Python オブジェクトにアタッチし、それらのオブジェクトを参照することです。

tf.train.Checkpointtf.keras.layers.Layerおよびtf.keras.Modelのサブクラスは、属性に割り当てられた変数を自動的に追跡します。以下の例では、単純な線形モデルを作成し、モデルのすべての変数の値を含むチェックポイントを記述します。

Model.save_weightsで、モデルチェックポイントを簡単に保存できます。

手動チェックポイント

セットアップ

tf.train.Checkpoint のすべての機能を実演するために、トイデータセットと最適化ステップを次のように定義します。

def toy_dataset(): inputs = tf.range(10.)[:, None] labels = inputs * 5. + tf.range(5.)[None, :] return tf.data.Dataset.from_tensor_slices( dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer): """Trains `net` on `example` using `optimizer`.""" with tf.GradientTape() as tape: output = net(example['x']) loss = tf.reduce_mean(tf.abs(output - example['y'])) variables = net.trainable_variables gradients = tape.gradient(loss, variables) optimizer.apply_gradients(zip(gradients, variables)) return loss

チェックポイントオブジェクトを作成する

チェックポイントを手動で作成するには、tf.train.Checkpoint オブジェクトを使用します。チェックポイントを設定するオブジェクトは、オブジェクトの属性として設定されます。

tf.train.CheckpointManagerは、複数のチェックポイントの管理にも役立ちます。

opt = tf.keras.optimizers.Adam(0.1) dataset = toy_dataset() iterator = iter(dataset) ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator) manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

モデルをトレーニングおよびチェックポイントする

次のトレーニングループは、モデルとオプティマイザのインスタンスを作成し、それらをtf.train.Checkpointオブジェクトに集めます。それはデータの各バッチのループ内でトレーニングステップを呼び出し、定期的にチェックポイントをディスクに書き込みます。

def train_and_checkpoint(net, manager): ckpt.restore(manager.latest_checkpoint) if manager.latest_checkpoint: print("Restored from {}".format(manager.latest_checkpoint)) else: print("Initializing from scratch.") for _ in range(50): example = next(iterator) loss = train_step(net, example, opt) ckpt.step.assign_add(1) if int(ckpt.step) % 10 == 0: save_path = manager.save() print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path)) print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)

復元してトレーニングを続ける

最初のトレーニングサイクルの後、新しいモデルとマネージャーを渡すことができますが、トレーニングはやめた所から再開します。

opt = tf.keras.optimizers.Adam(0.1) net = Net() dataset = toy_dataset() iterator = iter(dataset) ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator) manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3) train_and_checkpoint(net, manager)

tf.train.CheckpointManagerオブジェクトは古いチェックポイントを削除します。上記では、最新の 3 つのチェックポイントのみを保持するように構成されています。

print(manager.checkpoints) # List the three remaining checkpoints

これらのパス、例えば'./tf_ckpts/ckpt-10'などは、ディスク上のファイルではなく、indexファイルのプレフィックスで、変数値を含む 1 つまたはそれ以上のデータファイルです。これらのプレフィックスは、まとめて単一のcheckpointファイル('./tf_ckpts/checkpoint')にグループ化され、CheckpointManagerがその状態を保存します。

!ls ./tf_ckpts

読み込みの仕組み

TensorFlowは、読み込まれたオブジェクトから始めて、名前付きエッジを持つ有向グラフを走査することにより、変数をチェックポイントされた値に合わせます。エッジ名は通常、オブジェクトの属性名に由来しており、self.l1 = tf.keras.layers.Dense(5)"l1"などがその例です。tf.train.Checkpointは、tf.train.Checkpoint(step=...)"step"のように、キーワード引数名を使用します。

上記の例の依存関係グラフは次のようになります。

Visualization of the dependency graph for the example training loop

オプティマイザは赤、通常の変数は青、オプティマイザスロット変数はオレンジで表されています。tf.train.Checkpoint を表すノードなどは黒で示されています。

オプティマイザは赤色、通常変数は青色、オプティマイザスロット変数はオレンジ色です。他のノード、例えばtf.train.Checkpointを表すものは黒色です。

tf.train.Checkpoint オブジェクトで restore を読み出すと、リクエストされた復元がキューに入れられ、Checkpoint オブジェクトから一致するパスが見つかるとすぐに変数値が復元されます。たとえば、ネットワークとレイヤーを介してバイアスのパスを再構築すると、上記で定義したモデルからそのバイアスのみを読み込むことができます。

to_restore = tf.Variable(tf.zeros([5])) print(to_restore.numpy()) # All zeros fake_layer = tf.train.Checkpoint(bias=to_restore) fake_net = tf.train.Checkpoint(l1=fake_layer) new_root = tf.train.Checkpoint(net=fake_net) status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/')) print(to_restore.numpy()) # We get the restored value now

これらの新しいオブジェクトの依存関係グラフは、上で書いたより大きなチェックポイントのはるかに小さなサブグラフです。 これには、バイアスと tf.train.Checkpoint がチェックポイントに番号付けするために使用する保存カウンタのみが含まれます。

Visualization of a subgraph for the bias variable

restore は、オプションのアサーションを持つステータスオブジェクトを返します。新しい Checkpoint で作成されたすべてのオブジェクトが復元されるため、status.assert_existing_objects_matched がパスとなります。

status.assert_existing_objects_matched()

チェックポイントには、レイヤーのカーネルやオプティマイザの変数など、一致しない多くのオブジェクトがあります。status.assert_consumed() は、チェックポイントとプログラムが正確に一致する場合に限りパスするため、ここでは例外がスローされます。

復元延期 (Deferred restoration)

TensorFlow のLayerオブジェクトは、入力形状が利用可能な場合、最初の呼び出しまで変数の作成を遅らせる可能性があります。例えば、Denseレイヤーのカーネルの形状はレイヤーの入力形状と出力形状の両方に依存するため、コンストラクタ引数として必要な出力形状は、単独で変数を作成するために充分な情報ではありません。Layerの呼び出しは変数の値も読み取るため、復元は変数の作成とその最初の使用の間で発生する必要があります。

このイディオムをサポートするために、tf.train.Checkpoint は一致する変数がまだない場合、復元を延期します。

deferred_restore = tf.Variable(tf.zeros([1, 5])) print(deferred_restore.numpy()) # Not restored; still zeros fake_layer.kernel = deferred_restore print(deferred_restore.numpy()) # Restored

チェックポイントを手動で検査する

tf.train.load_checkpoint は、チェックポイントのコンテンツにより低いレベルのアクセスを提供する CheckpointReader を返します。これには各変数のキーからチェックポイントの各変数の形状と dtype へのマッピングが含まれます。変数のキーは上に表示されるグラフのようなオブジェクトパスです。

注意: チェックポイントへのより高いレベルの構造はありません。変数のパスと値のみが認識されており、modelslayers、またはそれらがどのように接続されているかについての概念が一切ありません。

tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))

net.l1.kernel の値に関心がある場合は、次のコードを使って値を取得できます。

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE' print("Shape:", shape_from_key[key]) print("Dtype:", dtype_from_key[key].name)

また、変数の値を検査できるようにする get_tensor メソッドも提供されています。

reader.get_tensor(key)

オブジェクトの追跡

self.l1 = tf.keras.layers.Dense(5)のような直接の属性割り当てと同様に、リストとディクショナリを属性に割り当てると、それらの内容を追跡します。

self.l1 = tf.keras.layers.Dense(5)のような直接の属性割り当てと同様に、リストとディクショナリを属性に割り当てると、それらの内容を追跡します。

save = tf.train.Checkpoint() save.listed = [tf.Variable(1.)] save.listed.append(tf.Variable(2.)) save.mapped = {'one': save.listed[0]} save.mapped['two'] = save.listed[1] save_path = save.save('./tf_list_example') restore = tf.train.Checkpoint() v2 = tf.Variable(0.) assert 0. == v2.numpy() # Not restored yet restore.mapped = {'two': v2} restore.restore(save_path) assert 2. == v2.numpy()

リストとディクショナリのラッパーオブジェクトにお気づきでしょうか。これらのラッパーは基礎的なデータ構造のチェックポイント可能なバージョンです。属性に基づく読み込みと同様に、これらのラッパーは変数の値がコンテナに追加されるとすぐにそれを復元します。

restore.listed = [] print(restore.listed) # ListWrapper([]) v1 = tf.Variable(0.) restore.listed.append(v1) # Restores v1, from restore() in the previous cell assert 1. == v1.numpy()

追跡可能なオブジェクトには tf.train.Checkpointtf.Module およびそのサブクラス ( keras.layers.Layerkeras.Model など)、および認識された Python コンテナが含まれています。

  • dict (および collections.OrderedDict)

  • list

  • tuple (および collections.namedtupletyping.NamedTuple)

以下のような他のコンテナタイプはサポートされていません

  • collections.defaultdict

  • set

以下のような他のすべての Python オブジェクトは無視されます

  • int

  • string

  • float

まとめ

TensorFlow オブジェクトは、それらが使用する変数の値を保存および復元するための容易で自動的な仕組みを提供します。