Path: blob/master/site/ja/tutorials/distribute/custom_training.ipynb
25118 views
Copyright 2019 The TensorFlow Authors.
tf.distribute.Strategy を使用したカスタムトレーニング
このチュートリアルでは、複数の処理ユニット(GPU、複数のマシン、または TPU)にトレーニングを分散するための抽象化を提供する tf.distribute.Strategy
という TensorFlow API をカスタムトレーニングループで使用する方法を説明します。この例では、70,000 個の 28 x 28 のサイズの画像を含む Fashion MNIST データセットで、単純な畳み込みニューラルネットワークをトレーニングします。
カスタムトレーニングループを使用すると、より優れた制御によってトレーニングを柔軟に実行できます。また、モデルとトレーニングループのデバックもより簡単に行えるようになります。
Fashion MNIST データセットをダウンロードする
変数とグラフを分散させるストラテジーを作成する
tf.distribute.MirroredStrategy
ストラテジーはどのように機能するのでしょう?
すべての変数とモデルグラフはレプリカ上に複製されます。
入力はレプリカ全体に均等に分散されます。
各レプリカは受け取った入力の損失と勾配を計算します。
勾配は加算して全てのレプリカ間で同期されます。
同期後、各レプリカ上の変数のコピーにも同じ更新が行われます。
注意: 下のコードはすべて 1 つのスコープ内に入れることができます。説明しやすいように、この例では複数のコードセルに分割しています。
入力パイプラインをセットアップする
データセットを作成して、それを分散します。
モデルを作成する
tf.keras.Sequential
を使用してモデルを作成します。これには、Model Subclassing API や functional API も使用できます。
損失関数を定義する
損失関数は以下の 2 つの部分で構成されていることを思い出しましょう。
予測損失は、モデルの予測が、トレーニングサンプルのバッチに対するトレーニングラベルからどれくらい外れているかを測定します。ラベル付きのサンプルごとに計算されてから、平均値を計算してバッチ全体で縮小されます。
オプションの 正則化損失の項を予測損失に追加して、モデルがトレーニングデータを過学習しないように誘導します。一般的には L2 正則化が使用されます。これは、サンプルの数に関係なく、すべてのモデルの重みの二乗和の小さな固定倍数を追加します。上記のモデルは L2 正則化を使用して、以下のトレーニングループでの処理を示しています。
単一の GPU/CPU を使った単一のマシンでのトレーニングでは、次のように動作します。
予測損失は、バッチのサンプルごとに計算され、バッチ全体で加算され、バッチサイズで除算されます。
正則化損失は、予測損失に追加されます。
合計損失の勾配は各モデルの重みに関して計算され、オプティマイザが、対応する勾配から各モデルの重みを更新します。
tf.distribute.Strategy
では、入力バッチはレプリカ間で分割されます。たとえば、GPU が 4 つあり、それぞれにモデルのレプリカが 1 つあるとします。1 つのバッチの 256 の入力サンプルは 4 つのレプリカで均等に分散されるため、各レプリカのバッチサイズは 64 となります。したがって、256 = 4*64
、または一般に GLOBAL_BATCH_SIZE = num_replicas_in_sync * BATCH_SIZE_PER_REPLICA
があることになります。
各レプリカは、それが得るトレーニングサンプルから損失を計算し、各モデルの重みに関する損失の勾配を計算します。オプティマイザは、これらの勾配をレプリカ全体で加算してから、レプリカごとにモデルの重みのコピーを更新します。
では、tf.distribute.Strategy
を使用する場合、どのように損失を計算すればよいのでしょうか。
各レプリカは、それに分散されたすべてのサンプルの予測損失を計算し、結果を加算して、
num_replicas_in_sync * BATCH_SIZE_PER_REPLICA
またはGLOBAL_BATCH_SIZE
で除算します。各レプリカは正則化損失を計算し、それを
num_replicas_in_sync
で除算します。
非分散型トレーニングに桑ベルト、すべてのレプリカ単位の損失項は 1/num_replicas_in_sync
の計数でスケールダウンされます。一方、すべての損失項、または勾配は、オプティマイザが適用する前にレプリカの数で加算されます。実際、各レプリカのオプティマイザは、GLOBAL_BATCH_SIZE
による非分散型計算が行われなかったかのようにして、同じ勾配を使用します。これは、分散型と非分散型の Keras Model.fit
の動作と同じです。より大きなグローバルバッチサイズによって学習率のスケールアップが可能になるかについて、Keras による分散型トレーニングをご覧ください。
TensorFlow では次のようにします。
この縮小とスケーリングは、Keras
Model.compile
とModel.fit
で自動的に行われます。If you're writing a custom training loop, as in this tutorial, you should sum the per example losses and divide the sum by the
GLOBAL_BATCH_SIZE
:scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)
or you can usetf.nn.compute_average_loss
which takes the per example loss, optional sample weights, andGLOBAL_BATCH_SIZE
as arguments and returns the scaled loss.tf.keras.losses
クラスを使用すると(以下の例)、損失の縮小をNONE
またはSUM
のいずれかになるように明示的に指定する必要があります。デフォルトのAUTO
とSUM_OVER_BATCH_SIZE
はModel.fit
の外では使用できません。AUTO
は、分散型のケースで正しくなるようにどの縮小を使用するかを明示的に考える必要があるため、使用できません。SUM_OVER_BATCH_SIZE
は、現在、レプリカごとのバッチサイズでのみ除算し、レプリカ数による除算をユーザーが処理しなければならないようになっていますが、見逃す可能性があるため、使用できなくなっています。したがって、ユーザー自身が縮小を明示的に行う必要があります。
空でない
Model.losses
リストのカスタムトレーニングループを書いている場合は(重みレギュラライザなど)、加算して、レプリカ数で除算する必要があります。これは、tf.nn.scale_regularization_loss
関数を使って行えます。モデルコード自体は、レプリカの数を認識していません。
ただし、モデルは、Layer.add_loss(...)
や Layer(activity_regularizer=...)
などの Keras API によって入力に依存する正則化損失を定義できます。Layer.add_loss(...)
の場合、モデリングコードが加算されたサンプルごとの項を tf.math.reduce_mean()
などを使ってレプリカ単位(!) のバッチサイズで除算します。
特殊ケース
高度なユーザーは、以下の特殊ケースについても考慮することをお勧めします。
GLOBAL_BATCH_SIZE
よりも短い入力バッチが原因で、いくつかの場所で好ましくない例外が発生します。実際には、Dataset.repeat().batch()
を使用してエポックの境界をまたぐバッチを許可し、データセットの終了ではなくステップ数でおおよそのエポック数を定義することで、例外を回避することがよくあります。または、Dataset.batch(drop_remainder=True)
は、エポックの表記を維持しながら、最後の数個のサンプルを除外します。
説明のために、この例ではより困難なルートを選択し、短いバッチを許可するため、トレーニングエポックごとに各トレーニング サンプルが 1 回だけ含まれます。
どのデノミネーターを tf.nn.compute_average_loss()
で使用すればよいでしょうか。
上記で説明するように、いずれのオプションも、短いバッチが回避されるのであれば同等です。
多次元
labels
では、各サンプルの予測数全体でper_example_loss
を平均化する必要があります。形状が(batch_size, H, W, n_classes)
のpredictions
と形状が(batch_size, H, W)
のlabels
を持つ入力画像のすべてのピクセルに対する分類タスクがあるとした場合、per_example_loss
は次のようにして更新する必要があります:per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)
注意:損失の形状を確認してください。 tf.losses
/tf.keras.losses
の損失関数は、通常、入力の最後の次元の平均を返します。損失クラスはこれらの関数をラップします。 損失クラスのインスタンスを作成するときにreduction=Reduction.NONE
を渡すことは、「追加の縮小がない」ことを意味します。[batch, W, H, n_classes]
の入力形状の例を使用したカテゴリ損失の場合、n_classes
次元が縮小されます。losses.mean_squared_error
またはlosses.binary_crossentropy
のような点ごとの損失の場合、ダミー軸を用いて、[batch, W, H, 1]
を[batch, W, H]
に縮小します。ダミー軸がないと、[batch, W, H]
は誤って[batch, W]
に縮小されます。
損失と精度を追跡するメトリクスを定義する
これらのメトリクスは、テストの損失、トレーニング、テストの精度を追跡します。.result()
を使用して、いつでも累積統計を取得できます。
トレーニングループ
上記の例における注意点
for x in ...
コンストラクトを使用して、train_dist_dataset
とtest_dist_dataset
をイテレーションします。スケーリングされた損失は
distributed_train_step
の戻り値です。この値はtf.distribute.Strategy.reduce
呼び出しを使用してレプリカ間で集約され、次にtf.distribute.Strategy.reduce
呼び出しの戻り値を加算してバッチ間で集約されます。tf.keras.Metrics
は、tf.distribute.Strategy.run
によって実行されるtrain_step
およびtest_step
内で更新する必要があります。tf.distribute.Strategy.run
はストラテジー内の各ローカルレプリカの結果を返し、この結果の使用方法は多様です。reduce
で、集約された値を取得することができます。また、tf.distribute.Strategy.experimental_local_results
を実行して、ローカルレプリカごとに 1 つ、結果に含まれる値のリストを取得することもできます。
最新のチェックポイントを復元してテストする
tf.distribute.Strategy
でチェックポイントされたモデルは、ストラテジーの有無に関わらず復元することができます。
データセットのイテレーションの代替方法
イテレータを使用する
データセット全体ではなく、任意のステップ数のイテレーションを行う場合は、iter
呼び出しを使用してイテレータを作成し、そのイテレータ上で next
を明示的に呼び出すことができます。tf.function
の内側と外側の両方でデータセットのイテレーションを選択することができます。ここでは、イテレータを使用し tf.function
の外側のデータセットのイテレーションを実行する小さなスニペットを示します。
tf.function 内でイテレーションする
for x in ...
コンストラクトを使用して、または上記で行ったようにイテレータを作成して、tf.function
内で train_dist_dataset
の入力全体をイテレートすることもできます。以下の例では、1 エポックのトレーニングを @tf.function
デコレータでラップし、関数内で train_dist_dataset
をイテレーションする方法を示します。
レプリカ間でトレーニング損失を追跡する
注意: 一般的なルールとして、サンプルごとの値の追跡にはtf.keras.Metrics
を使用し、レプリカ内で集約された値を避ける必要があります。
損失スケーリングの計算が実行されるため、レプリカ間でトレーニング損失を追跡するために tf.keras.metrics.Mean
を使用することは推奨されません。
例えば、次のような特徴を持つトレーニングジョブを実行するとします。
レプリカ 2 つ
各レプリカで 2 つのサンプルを処理
結果の損失値 : 各レプリカで [2, 3] および [4, 5]
グローバルバッチサイズ = 4
損失スケーリングで損失値を加算して各レプリカのサンプルごとの損失の値を計算し、さらにグローバルバッチサイズで除算します。この場合は、(2 + 3) / 4 = 1.25
および(4 + 5) / 4 = 2.25
となります。
tf.keras.metrics.Mean
を使用して 2 つのレプリカ間の損失を追跡すると、異なる結果が得られます。この例では、total
は 3.50、count
は 2 となるため、メトリックで result()
が呼び出されると、total
/count
= 1.75 となります。tf.keras.Metrics
で計算された損失は、同期するレプリカの数に等しい追加の係数によってスケーリングされます。
ガイドと例
カスタムトレーニングループを用いた分散ストラテジーの使用例をここに幾つか示します。
分散型トレーニングガイド
MirroredStrategy
を使用した DenseNet の例。MirroredStrategy
とTPUStrategy
を使用してトレーニングされた BERT の例。この例は、分散トレーニングなどの間にチェックポイントから読み込む方法と、定期的にチェックポイントを生成する方法を理解するのに特に有用です。MirroredStrategy
を使用してトレーニングされ、keras_use_ctl
フラグを使用した有効化が可能な、NCF の例。MirroredStrategy
を使用してトレーニングされた、NMT の例。
その他の例は、分散型ストラテジーガイドの「例とチュートリアル」に記載されています。
次のステップ
新しい
tf.distribute.Strategy
API を独自のモデルで試してみましょう。TensorFlow モデルのパフォーマンスを最適化する方法についてのその他の詳細は、
tf.function
によるパフォーマンスの改善と TensorFlow Profiler をご覧ください。TensorFlow での分散型トレーニングガイドでは、利用可能な分散ストラテジーの概要が説明されています。