Copyright 2020 The TensorFlow Authors.
tf.function によるパフォーマンスの改善
TensorFlow 2 の Eager execution はデフォルトで有効になっています。ユーザーインターフェースは直感的で柔軟性に優れていますが(一度限りの演算の実行ははるかに簡単で高速に行われます)、パフォーマンスとデプロイ能力に影響がでることがあります。
プログラムからグラフを作成するには、tf.function
を使用できます。変換ツールで Python コードから Python に依存しないデータフローグラフを作成するため、パフォーマンスと移植性に優れたモデルを作成できます。また、SavedModel
を使用する際に必要となります。
このチュートリアルでは tf.function
と AutoGraph の基本的な特徴についてひととおり確認します。
主に次の内容と推奨事項について説明しています。
Eager モードでデバッグしてから、
@tf.function
でデコレートする。オブジェクトミューテーションまたはリストの追加といった Python 側の効果に依存しないこと。
tf.function
は TensorFlow 演算子と最も相性が良く、NumPy と Python 呼び出しは定数に変換される。
セットアップ
発生する可能性のあるエラーの種類を示すヘルパー関数を定義します。
基礎
使い方
定義する Function
(@tf.function
デコレーターを適用するなどして)は、コアの TensorFlow 演算とまったく変わりません。Eager での実行や勾配の計算などを行えます。
Function
をほかの Function
内で使用できます。
Function
は、特に小さな演算が多数含まれるグラフでは、Eager コードよりも高速に実行されることがありますが、高価な演算がいくつか含まれるグラフ(畳み込みなど)では、速度の差はあまり見られません。
トレーシング
このセクションは、Function
の内部動作や実装の詳細を説明します。将来的に変更する可能性がありますが、いつなぜトレーシングが発生するのかを理解しておけば、tf.function
を効果的に使用しやすくなります。
「トレーシング」とは?
Function
は TensorFlow Graph でプログラムを実行しますが、tf.Graph
は、Eager TensorFlow プログラムにユーザーが記述するすべてのものを表現することはできません。たとえば、Python はポリモーフィズムをサポートしていますが、tf.Graph
では、その入力に特定のデータ型と次元が必要です。またはコマンドラインの引数を読み取る、エラーを発生させる、より複雑な Python オブジェクトを扱うといったサイドタスクを実施しようとしても、どれも tf.Graph
で実行することはできません。
Function
はコードを 2 つの段階に分けることで、このギャップの橋渡しの役割を果たします。
「トレーシング」と呼ばれる第 1 段階において、
Function
は新しいtf.Graph
を作成します。Python コードは通常通り実行しますが、すべての TensorFlow 演算(2 つのテンソルを加算するなど)は 据え置きとなります。これらはtf.Graph
にとらわれるため、実行しません。第 2 段階では、最初の段階で据え置きとなったすべての演算を含む
tf.Graph
が実行されます。この段階は、トレーシングの段階よりもはるかに高速に行われます。
Function
は、その入力によっては必ずしも最初の段階で呼び出されたときに実行するわけではありません。この判定がどのように行われるのかについては、以下の「トレーシングの規則」をご覧ください。最初の段階を省略して 2 番目の段階のみを実行できれば、TensorFlow の高いパフォーマンスが発揮されます。
Function
がトレーシングしないと判断した場合、トレーシング段階の直後に 第 2 段階が始まるため、Function
を呼び出すと、tf.Graph
の作成と実行が行われます。後の方で、get_concrete_function
を使ってトレーシング段階のみを実行する方法を説明します。
型の異なる引数を Function
に渡すと、両方の段階が実行されます。
同じ型の引数で Function
を繰り返し呼び出すと、生成されるグラフはまったく同じになるため、TensorFlow はトレーシング段階を省略して前にトレーシングしたグラフを再利用することに注意してください。
すべての利用可能なトレースを確認するには、pretty_printed_concrete_signatures()
を使用できます。
ここまで、tf.function
が TensorFlow のグラフトレーシングロジックにキャッシュされた動的ディスパッチレイヤーを作成するのを見てきました。用語についてより具体的に説明すると、次のように言えます。
tf.Graph
は、言語に依存しない、生の移植可能な TensorFlow 計算の表現です。ConcreteFunction
はtf.Graph
をラップします。Function
はConcreteFunction
のキャッシュを管理し、入力に適したものを選択します。tf.function
は Python 関数をラップし、Function
オブジェクトを返します。トレーシングは
tf.Graph
を作成し、それをConcreteFunction
(またはトレース)をラップします。
トレーシングの規則
Function
が呼び出されると、各引数の tf.types.experimental.TraceType
を使用して呼び出し引数を既存の ConcreteFunction
に一致させます。一致する ConcreteFunction
が見つかった場合、呼び出しはそれにディスパッチされます。一致するものが見つからない場合、新しい ConcreteFunction
がトレースされます。
複数の一致が見つかった場合は、最も具体的なシグネチャが選択されます。マッチングは、たとえば C++ や Java での通常の関数呼び出しと同じように、サブタイプ化によって行われます。例えば、TensorShape([1, 2])
は TensorShape([None, None])
のサブタイプ化であるため、TensorShape([1, 2])
を使用した tf.function への呼び出しは、TensorShape([None, None])
で生成された ConcreteFunction
にディスパッチできます。しかし、TensorShape([1, None])
を持つ ConcreteFunction
も存在する場合は、より具体的であるため優先されます。
TraceType
は、次のように入力引数から決定されます。
Tensor
の場合、型はTensor
のdtype
とshape
によってパラメータ化されます。階数付けされた形状は、階数付けされていない形状のサブタイプです。固定次元は未知次元のサブタイプですVariable
の場合、型はTensor
に似ていますが、変数の一意のリソース ID も含まれています。これは、制御の依存関係を正しく設定するために必要です。Python プリミティブ値の場合、型は値自体に対応します。たとえば、値
3
のTraceType
は、int
ではなくLiteralTraceType<3>
です。list
やtuple
などの Python の順序付きコンテナの場合、型はそれらの要素の型によってパラメータ化されます。たとえば、[1, 2]
の型はListTraceType<LiteralTraceType<1>, LiteralTraceType<2>>
であり、[2, 1]
の型はListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>
であり、異なります。dict
などの Python マッピングの場合、型も同じキーからのマッピングですが、実際の値ではなく値の型へのマッピングです。たとえば、{1: 2, 3: 4}
の型はMappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>>
です。ただし、順序付きコンテナとは異なり、{1: 2, 3: 4}
と{3: 4, 1: 2}
の型は同等です。__tf_tracing_type__
メソッドを実装する Python オブジェクトの場合、型はそのメソッドが返すものですその他の Python オブジェクトの場合、型はジェネリックの
TraceType
で、マッチング手順は以下のとおりです。まず、オブジェクトが前のトレースで使用されたオブジェクトと同じであるかをチェックします(Python の
id()
またはis
を使用します)。オブジェクトが変更された場合でも一致することに注意してください。そのため、Python オブジェクトをtf.function
の引数として使用する場合、イミュータブルを使用するのが最適です。次に、オブジェクトが前のトレースで使用されたオブジェクトと同じであるかをチェックします(Python の
==
を使用)。
この手順では、オブジェクトへの weakref のみが維持されるため、オブジェクトが範囲内または削除されていない場合にのみ機能します。
リトレーシングの制御
リトレーシングは、Function
が 2 つ以上のトレースを作成する際に発生します。これは、TensorFlow が一連の入力ごとに正しいグラフを生成する上で役立ちますが、トレーシングは高価な演算です!Function
が呼び出しごとに新しいグラフをリトレーシングすると、コードの実行は tf.function
を使用しない場合よりも遅くなってしまいます。
トレーシングの動作を制御するには、次のテクニックを使用できます。
固定の input_signature
を tf.function
に渡す
柔軟性のために未知の次元を使用する
TensorFlow は形状に基づいてテンソルを一致させるため、ワイルドカードとして None
次元を使用することで、Function
が可変サイズの入力にトレースを再利用できるようになります。可変サイズの入力は、長さの異なるシーケンスがある場合や、バッチごとに画像のサイズが異なる場合に発生します(例として、Transformer と Deep Dream チュートリアルをご覧ください)。
Python リテラルの代わりにテンソルを渡す
通常、Python 引数は、num_layers=10
または training=True
または nonlinearity='relu'
などのように、ハイパーパラメータとグラフ構造の制御に使用されます。そのため、Python 引数が変わると、当然グラフをリトレースする必要が出てきます。
しかし、Python 引数がグラフ構造の制御に使用されていない場合もあります。こういった場合、Python の値の変化によってリトレーシングがトリガーされますが、これは不要です。この、AutoGraph が動的にアンロールするトレーニングループを例に見てみましょう。トレースが何度も行われますが、生成されたグラフはまったく同じであるため、リトレーシングは不要と言えます。
リトレーシングを強制する必要がある場合は、新しい Function
を作成します。トレースは絶対に、各 Function
オブジェクト間で共有されることはありません。
トレースプロトコルを使用する
可能であれば、代わりに Python 型を tf.experimental.ExtensionType
に変換することをお勧めします。さらに、ExtensionType
の TraceType
は、それに関連付けられた tf.TypeSpec
です。したがって、必要に応じて、デフォルトの tf.TypeSpec
を単純にオーバーライドして、ExtensionType
の Tracing Protocol
を制御できます。詳細については、拡張型ガイドの ExtensionType の TypeSpec のカスタマイズセクションをご覧ください。
それ以外の場合は、Function
が特定の Python 型に関していつ再トレースする必要があるかを直接制御するために、Tracing Protocol
を自分で実装できます。
具象関数の取得
関数がトレースされるたびに新しい具象関数が作成されますが、get_concrete_function
を使うことで、具象関数を直接取得できます。
ConcreteFunction
を出力すると、入力引数(型付き)とその出力型の概要が表示されます。
また、具象関数のシグネチャを直接取得することもできます。
互換性のない型で具象トレースを使用すると、エラーが発生します。
Python 引数は、具象関数の入力シグネチャで特別に扱われていることに気づいたかもしれません。TensorFlow 2.3 より前では、Python 引数は単に具象関数のシグネチャから削除されていましたが、TensorFlow 2.3 からはシグネチャに残されたまま、トレーシング中に値セットを取るように制約されています。
グラフの取得
それぞれの具象関数は、tf.Graph
を囲む呼び出し可能なラッパーです。通常、実際の tf.Graph
オブジェクトを取得する必要はないにしろ、具象関数から簡単に取得することが可能です。
デバッグ
一般的に、コードのデバックは、tf.function
内で行うよりも、Eager モードで行う方が簡単です。Eager モードでは、tf.function
でデコレートする前に、コードがエラーなく実行することを確認しておく必要があります。デバッグプロセスを支援する目的で、tf.config.run_functions_eagerly(True)
を呼び出すと、tf.function
をグローバルに無効にして、有効にし直すことができます。
tf.function
内でのみ出現する問題を追跡する場合、次のようなヒントがあります。
従来のシンプルな Python
print
呼び出しは、トレーシング中にのみ実行されるため、関数が(リ)トレーシングされるときに追跡しやすくなります。tf.print
呼び出しは毎回実行するため、実行中の中間値の追跡に役立ちます。tf.debugging.enable_check_numerics
は、NaN と Inf がいつ作成されるかを簡単に追跡できます。pdb
(Python デバッガ)は、トレーシング中に何が起きているのかを理解する上で役立ちます。(注意:pdb
が示すのは、AutoGraph 変換ソースコードです。)
AutoGraph 変換
AutoGraph は、tf.function
内でデフォルトで利用できるようになっているライブラリで、Python の Eager コードのサブセットとグラフ対応の TensorFlow 演算に変換します。これには、if
、for
、while
などの制御フローが含まれます。
tf.cond
や tf.while_loop
などの TensorFlow 演算は機能し続けますが、制御フローは、Python で記述された場合の方が書きやすく理解しやすいことがほとんどです。
興味があれば、AutoGraph が生成するコードを検査できます。
条件文
AutoGraph は if <condition>
文を相当する tf.cond
呼び出しに変換します。この置換は、<condition>
がテンソルである場合に行われます。テンソルでない場合は、if
文は Python の条件文として実行されます。
Python 条件文はトレーシング中に実行するため、条件文のブランチが 1 つだけグラフに追加されます。AutoGraph を使用しない場合、データに依存する制御フローが存在すると、トレーシングされたこのグラフは別のブランチを取ることができません。
tf.cond
は、条件文の両方のブランチをトレーシングし、実行時に動的に 1 つのブランチを選択してグラフに追加します。トレーシングには意図しない副作用がある場合があります。詳細は、AutoGraph のトレーシング効果をご覧ください。
AutoGraph 変換の if 文におけるその他の制約事項については、リファレンスドキュメントをご覧ください。
ループ
AutoGraph は、一部の for
文と while
文を相当する tf.while_loop
などの TensorFlow のループ演算に変換します。変換されない場合、for
または while
ループは Python ループとして実行されます。
この置き換えは、次の場合に行われます。
for x in y
:y
がテンソルである場合、tf.while_loop
に変換されます。y
がtf.data.Dataset
である特別なケースでは、tf.data.Dataset
演算の組み合わせが生成されます。while <condition>
:<condition>
がテンソルである場合、tf.while_loop
に変換されます。
Python ループは、トレーシング中に実行され、ループのいてレーションごとに、tf.Graph
に追加の演算が追加されます。
TensorFlow ループはループの本体をトレーシングし、実行時に実行する反復回数を動的に選択します。ループ本体は、生成された tf.Graph
に一度だけ出現します。
AutoGraph 変換の for
文と while
文におけるその他の制約事項については、リファレンスドキュメントをご覧ください。
Python データのループ
一般的な落とし穴は、tf.function
内で Python/NumPy データをループする際にあります。このループは、トレーシングプロセス中に実行し、ループのイテレーションごとにモデルのコピーを tf.Graph
に追加してしまいます。
トレーニングループ全体を tf.function
にラップしたいのであれば、データを tf.data.Dataset
としてラップし、AutoGraph にトレーニングループを動的に展開させるようにするのが最も安全な方法です。
Python/NumPy データをデータセットにラップする際は、tf.data.Dataset.from_generator
と tf.data.Dataset.from_tensor_slices
の違いに注意してください。前者は、データを Python に維持し、tf.py_function
経由で取得するため、パフォーマンスに問題がありますが、後者は、データのコピーをグラフ内の大型の tf.constant()
ノードとしてバンドル化するため、メモリに問題が現れます。
データを消費するには、TFRecordDataset
や CsvDataset
などを介してファイルからデータを読み取るのが最も効果的な方法です。そうすれば、Python を使わずに、TensorFlow 自体でデータの非同期読み込みとプリフェッチを管理できるようになります。詳細は、「tf.data
: TensorFlow 入力パイプラインを構築する」ガイドをご覧ください。
ループでの値の累積
ループの反復ごとに値を累積していくのは一般的なパターンです。通常は、Python のリストに追加したり、Python ディクショナリにエントリを追加したりして行われますが、これらは Python の副作用であるため、動的に展開されるループでは期待どおりに動作しません。動的に展開されるループの結果を累積する場合は、tf.TensorArray
を使用してください。
制限事項
TensorFlow の Function
には、設計上、いくつかの制限事項があり、Python 関数を Function
に変換する際には、注意が必要です。
Python の副作用の実行
Function
内での出力、リストへのアペンド、グローバル変数のミューテーションといった副作用は、2 回実行されたり、まったく実行しなかったりといったように、予測のつかない動作をすることがあります。また、入力セットで Function
を初めて呼び出した場合にのみ実行し、以降では、Python コードを実行せずに、トレーシング済みの tf.Graph
が再実行されてしまうこともあります。
基本的に、ロジックでは Python の副作用に依存しないようにし、トレースをデバッグするためだけに使用することをお勧めします。呼び出しごとに TensorFlow ランタイムが確実にコードを実行できるようにするには、tf.data
、tf.print
、tf.summary
、tf.Variable.assign
、tf.TensorArray
などの TensorFlow API を使用するのが最善の方法です。
Function
の呼び出しごとに Python コードを実行する場合は、tf.py_function
が脱出口です。tf.py_function
には移植性がなく、特にパフォーマンスに優れているわけでもなく、SavedModel で保存できなければ、分散型(マルチ GPU、TPU)の環境でうまく動作するわけでもありません。また、tf.py_function
はグラフに組み込む必要もあるため、すべての入力/出力をテンソルにキャストしてしまいます。
Python のグローバル変数と自由変数の変更
Python のグローバル変数と自由変数の変更は、Python の副作用としてみなされるため、トレーシング中にのみ発生します。
場合によっては気づきにくい予期しない動作が発生することがあります。以下の例では、counter
は変数のインクリメントを保護することを目的としています。ただし、これは Python 整数であり、TensorFlow オブジェクトではないため、その値は最初のトレース中にキャプチャされます。tf.function
を使用すると、assign_add
が下のグラフに無条件に記録されます。したがって、v
は、tf.function
が呼び出されるたびに 1 ずつ増加します。この問題は、tf.function
デコレータを使用して Grpah モードの Tensorflow コードを Tensorflow 2 に移行しようとする場合、Python の副作用 (例では counter
) を使用して、実行する演算を決定すると (例では、assign_add
)によく発生します。通常、ユーザーは、疑わしい数値結果を確認したり、予想よりもパフォーマンスが大幅に低下した場合に、このことに気付きます(たとえば、保護された演算に非常にコストがかかる場合)。
このような動作を回避し、期待される動作を実現するためには、tf.init_scope
を使用して演算を関数グラフの外に移動します。これにより、変数のインクリメントがトレース時間中に 1 回だけ実行されるようになります。init_scope
には、制御フローのクリアや勾配テープなどの他の副作用があることに注意してください。init_scope
を使用すると非常に複雑になり、現実的に管理できない場合があります。
まとめると、経験則として、Function
の外側で機能する整数またはリストのようなコンテナなどの Python オブジェクトのミューテーションは避けてください。代わりに、引数と TF オブジェクトを使用しましょう。たとえば、「ループでの値の累積」セクションには、リストのような演算を実装する方法の一例が示されています。
一部のケースでは、tf.Variable
である場合に状態をキャプチャして操作することができます。Keras モデルの重みは、このようにして、同じ ConcreteFunction
への呼び出しの繰り返しで更新されています。
Python イテレータとジェネレータの使用
ジェネレータやイテレータなどの多くの Python 機能は、Python ランタイムに依存して状態を追跡しています。一般的に、これらのコンストラクトは Eager モードでも期待どおりに動作しますが、Python の副作用の例であるため、トレーシング中にしか発生しません。
TensorFlow にリストコントラクト用の特別な tf.TensorArray
があるように、イテレーション用にも特別な tf.data.Iterator
があります。概要は、AutoGraph 変換をご覧ください。また、tf.data
API を使って、ジェネレータのパターンを実装できます。
tf.function のすべての出力は値を返す必要がある
tf.Variable
を除いて、tf.function はすべての出力を返す必要があります。戻り値を使用せずに関数からテンソルに直接アクセスしようとすると、「リーク」が発生します。
たとえば、以下の関数は、Python グローバル x
を介してテンソル a
を「リーク」します。
リークされた値も返される場合でもリークします。
通常、このようなリークは、Python ステートメントまたはデータ構造を使用するときに発生します。アクセスできないテンソルがリークするだけでなく、このようなステートメントは Python の副作用としてカウントされ、すべての関数呼び出しで実行されないことがあるため、間違っている可能性があります。
また、一般的に外部 Python コレクションまたはオブジェクトの変更によりローカルテンソルがリークすることもあります。
再帰的な tf.function はサポートされていない
再帰的な Function
はサポートされていないので、無限ループを引き起こす可能性があります。以下に例を示します。
再帰的な Function
が正しく動作しているように見えても、Python 関数は複数回トレースされ、パフォーマンスに影響を与える可能性があります。以下に例を示します。
既知の問題
Function
が正しく評価していない場合、以下の既知の問題が該当する可能性があります。これらの問題は、今後修正される予定です。
Python のグローバル変数と自由変数への依存
Function
は、Python 引数の新しい値で呼び出された時に新しい ConcreteFunction
を作成しますが、Python クロージャ、グローバル変数、またはその Function
の非ローカル変数に対しては作成しません。Function
への呼び出しごとに値が変化する場合でも、Function
はトレーシングされたときの値をそのまま使用してしまいます。これは、通常の Python 関数の動作とは異なります。
このため、外側の名前を閉じる代わりに引数を使用する関数プログラミングの様式をお勧めします。
グローバル値を更新する別の方法として、それを tf.Variable
にし、代わりに Variable.assign
メソッドを使用することができます。
Python オブジェクトへの依存
カスタム Python オブジェクトを引数として tf.function
に渡すことはサポートされていますが、ある制限が伴います。
特徴量を最大限にカバーするには、オブジェクトを tf.function
に渡す前に拡張型に変換することを検討してください。Python プリミティブ型と tf.nest
対応構造も使用できます。
ただし、トレーシングのルールで説明されるように、カスタム TraceType
がカスタム Python クラスによって提供されない場合、tf.function
はインスタンスベースの等価性を使用するように強制されてしまいます。そのため、変更された属性と同じオブジェクトを渡しても、新しいトレースは作成されません。
変更されたモデルインスタンスの評価に同じ Function
を使用すると、元のモデルと同じインスタンスに基づく TraceTypeがあるために不具合が生じます。
そのため、ミュータブルオブジェクト属性に依存しない Function
を書くか、そのような属性を Function
に伝達するオブジェクトにトレーシングプロトコルを実装することをお勧めします。
この方法が困難な場合は、回避策として、オブジェクトを変更するたびに新しい Function
がリトレーシングを行うようにする方法が挙げられます。
リトレーシングにはコストがかかるため、tf.Variable
をオブジェクト属性として使用することができます。こうすることで、リトレーシングを行わずに、ミュートして(変更はしません)同様の効果を得ることができます。
tf.Variables の作成
Function
は、最初の呼び出しで 1 回作成され、後続の関数呼び出しで再利用されるシングルトン tf.Variable
のみをサポートします。以下のコードスニペットは、すべての関数呼び出しで新しい tf.Variable
を作成します。これにより、ValueError
例外が発生します。
例:
この制限を回避するために使用される一般的なパターンは、Python None 値で開始し、値が None の場合は条件付きで tf.Variable
を作成することです。
複数の Keras オプティマイザとの使用
2 つ以上の Keras オプティマイザを tf.function
で使用しようとすると、「ValueError: tf.function only supports singleton tf.Variables created on the first call.
」というエラーが発生することがあります。このエラーは、オプティマイザが初めて勾配を適用する際に、内部的に tf.Variables
を作成するために発生するものです。
トレーニング中にオプティマイザを変更する必要がある場合は、回避策として、オプティマイザごとに新しい Function
を作成し、ConcreteFunction
を直接呼び出すようにすることができます。
複数の Keras モデルとの使用
また、別のモデルインスタンスを同一の Function
に渡す際に、「ValueError: tf.function only supports singleton tf.Variables created on the first call.
」というエラーも発生することがあります。
このエラーは、Keras モデル(入力形状が定義されていない)と Keras レイヤーが、初めて呼び出されるときに tf.Variables
を作成するために発生するものです。これらの変数をすでに呼び出された Function
内で初期化しようとしているのでしょう。このエラーを回避するには、モデルをトレーニングする前に、model.build(input_shape)
を呼び出して、すべての重みを初期化するようにしてください。
参考資料
Function
のエクスポートと読み込みの方法については、SavedModel ガイドをご覧ください。トレーシングの後に実行するグラフの最適化については、Grappler ガイドをご覧ください。データパイプラインの最適化方法とモデルのプロファイリングについては、Profiler ガイドをご覧ください。