Copyright 2020 The TensorFlow Authors.
勾配と自動微分の基礎
自動微分と勾配
自動微分は、ニューラルネットワークをトレーニングするバックプロパゲーションなどの機械学習アルゴリズムの実装に有用です。
このガイドでは、特に Eager execution において、TensorFlow を使用して勾配を計算する方法について説明します。
セットアップ
勾配を計算する
TensorFlow は、自動的に微分するために、フォワードパス中にどのような演算がどの順序で行われたかを覚えておく必要があります。その後、TensorFlow は逆方向パス中にこの演算のリストを逆順に走査し、勾配を計算します。
勾配テープ
TensorFlow には、一部の入力、通常はtf.Variable
に関する計算の勾配を計算する、自動微分のための tf.GradientTape API があります。TensorFlow は、tf.GradientTape
のコンテキスト内で実行される関連の演算を「テープ」に「記録」します。その後、TensorFlow はそのテープを使い、リバースモード微分を使用して「記録」された計算の勾配を計算します。
簡単な例を示します。
いくつかの演算を記録してから、GradientTape.gradient(target, sources)
を使用して、いくつかのソース(多くの場合はモデルの変数)に対するいくつかのターゲット(多くの場合は損失)の勾配を計算します。
上記の例ではスカラーを使用していますが、どのテンソルでもtf.GradientTape
は簡単に機能します。
両方の変数に関する loss
勾配を取得するには、両方をソースとして gradient
メソッドに渡すことができます。テープはソースがどのように渡されるかについては柔軟であり、リストまたはディクショナリのネストされた組み合わせを受け入れ、同じ方法で構造化された勾配を返します(tf.nest
を参照)。
各ソースに関する勾配には、ソースの形状があります。
下記もまた勾配計算ですが、この例では変数のディクショナリを渡します。
デフォルトの動作では、トレーニング可能なtf.Variable
にアクセスした後、全ての演算を記録します。その理由は次の通りです。
逆方向パスの勾配を計算するために、テープはフォワードパス中のどの演算を記録するか、知っておく必要があります。
テープは中間出力への参照を保持するため、不要な演算を記録する必要はありません。
最も一般的な使用例として、モデルのトレーニング可能なすべての変数に対する損失の勾配の計算があります。
たとえば、次の例ではデフォルトで tf.Tensor
が「監視」されておらず、tf.Variable
はトレーニング対象外であるため、勾配の計算ができません。
GradientTape.watched_variables
メソッドを使用すると、テープが監視している変数を一覧表示できます。
tf.GradientTape
は、ユーザーが監視対象や非監視対象を制御できるフックを提供します。
tf.Tensor
に関する勾配を記録するには、GradientTape.watch(x)
を呼び出す必要があります。
逆に、すべてのtf.Variables
を監視するデフォルトの動作を無効にするには、勾配テープの作成時にwatch_accessed_variables=False
を設定します。この計算には 2 つの変数を使用しますが、1 つの変数の勾配のみに接続します。
GradientTape.watch
はx0
で呼び出されなかったため、それに関する勾配は計算されません。
中間結果
tf.GradientTape
コンテキスト中で計算された中間の値に対する出力の勾配をリクエストすることもできます。
デフォルトでは、ある GradientTape
に保持されたリソースは、GradientTape.gradient
メソッドが呼び出されるとすぐに解放されます。同じ計算で複数の勾配を計算する場合は、persistent=True
を指定した勾配テープを作成します。こうすると、テープオブジェクトのガベージコレクションを実行するときにリソースが解放されるため、gradient
メソッドを何度も呼び出すことができます。例を示します。
パフォーマンスに関する注記
勾配テープコンテキスト内の演算の実行に伴う、少量のオーバーヘッドがあります。ほとんどの Eager Execution では、これは目立ったコストにはなりませんが、それでもテープコンテキストは必要な領域のみで使用すべきです。
勾配テープは、メモリを使用して入力と出力を含む中間結果を格納し、逆方向パス中に使用します。
ReLU
などの一部の演算は、中間結果を保持する必要がないため、効率を上げるためにフォワードパス中に削除されます。ただし、テープにpersistent=True
を使用している場合は何も破棄されないため、ピーク時のメモリ使用量が高くなります。
非スカラーのターゲットの勾配
勾配は基本的にスカラーの演算です。
したがって、複数のターゲットの勾配を求める場合、各ソースの結果は次のようになります。
ターゲットの合計の勾配、あるいは
各ターゲットの勾配の合計
同様に、ターゲットがスカラーでない場合は、合計の勾配が計算されます。
これにより、損失を集めた合計の勾配、または要素ごとの損失計算の合計の勾配の取得が容易になります。
アイテムごとに個別の勾配が必要な場合は、Jacobians(ヤコビアン)をご覧ください。
場合によっては、ヤコビアンをスキップすることができます。要素についての計算の場合、各要素は独立しているため、合計の勾配は入力要素に対する各要素の導関数を与えます。
制御フロー
勾配テープは実行時に演算を記録するため、Python 制御フローは自然に処理されます(if
文や while
文など)。
ここでは、if
の各ブランチで異なる変数が使用されています。勾配は使用された変数にのみ接続します。
制御文自体は微分不可能なため、勾配ベースのオプティマイザから見えないということに注意してください。
上記の例のx
の値次第で、テープはresult = v0
またはresult = v1**2
を記録します。x
に対する勾配は常にNone
です。
gradient
がNone
を返すケース
ターゲットがソースに接続されていない場合、勾配はNone
になります。
ここでは明らかにz
がx
に接続されていませんが、これほど明白ではないにしても勾配が非接続になりうる場合がいくつかあります。
1. 変数をテンソルに置換した場合
「テープの監視対象を制御する」のセクションで説明したようにテープは自動的にtf.Variable
を監視しますが、tf.Tensor
は監視しません。
よくあるエラーの 1 つは、tf.Variable
を更新するためにVariable.assign
を使用する代わりに、tf.Variable
をtf.Tensor
で置き換えてしまうことです。例を示します。
2. TensorFlowの外で計算をした
計算が TensorFlow から出てしまうと、テープは勾配パスを記録することができません。例を示します。
3. 整数または文字列を使用して勾配を取得した
整数と文字列は微分不可能です。計算パスがこれらのデータ型を使用する場合は、勾配を取得できません。
文字列は微分不可能だと知っていても、dtype
を指定していない場合に、うっかりint
定数や変数を作成してしまう可能性があります。
TensorFlow は、型間で自動的にキャストしないため、実際には、欠損した勾配の代わりに型のエラーが表示されることがよくあります。
4. ステートフルオブジェクトを使用して勾配を取得した
状態は勾配を停止します。ステートフルオブジェクトから読み取る場合、テープはその時点の状態のみを確認し、その状態に至るまでの履歴を確認できません。
tf.Tensor
は不変で、一旦作成したテンソルは変更できません。値はありますが、状態はありません。これまでに説明したすべての演算もステートレスで、tf.matmul
の出力はその入力のみに依存します。
tf.Variable
tf.Variable には内部状態とその値があるため、変数を使用するとその状態が読み取られます。変数に関する勾配を計算するのは通例ですが、変数の状態によって勾配の計算をさかのぼって行うことはできません。次に例を示します。
同様に、tf.data.Dataset
イテレータと tf.queue
はステートフルであるため、それらを通過するテンソルのすべての勾配を停止します。
勾配が登録されていない
一部のtf.Operation
は微分不可能として登録されているため、None
を返します。その他は勾配の登録がされていません。
tf.raw_ops のページには、勾配を登録する低レベルの演算が示されています。
勾配が登録されていない浮動小数点演算を介して勾配を取得しようとすると、テープは暗黙的にNone
を返す代わりにエラーをスローします。これにより、何かが間違っていることが分かります。
たとえば、tf.image.adjust_contrast
関数は、raw_ops.AdjustContrastv2
をラップしており、勾配があっても、その勾配は実装されていません。
この演算で微分する必要がある場合は、勾配を実装して登録する(tf.RegisterGradient
を使用)か、ほかの演算を使用して関数を再実装する必要があります。
None の代わりにゼロを取得する
場合によっては、接続されていない勾配でNone
ではなく 0(ゼロ)を取得すると便利です。unconnected_gradients
引数を使用すると、接続されていない勾配がある場合に何を返すかを決めることができます。