Copyright 2020 The TensorFlow Authors.
tf.function によるパフォーマンスの改善
TensorFlow 2 の Eager execution はデフォルトで有効になっています。ユーザーインターフェースは直感的で柔軟性に優れていますが(一度限りの演算の実行ははるかに簡単で高速に行われます)、パフォーマンスとデプロイ能力に影響がでることがあります。
プログラムからグラフを作成するには、tf.function を使用できます。変換ツールで Python コードから Python に依存しないデータフローグラフを作成するため、パフォーマンスと移植性に優れたモデルを作成できます。また、SavedModel を使用する際に必要となります。
このチュートリアルでは tf.function と AutoGraph の基本的な特徴についてひととおり確認します。
主に次の内容と推奨事項について説明しています。
Eager モードでデバッグしてから、
@tf.functionでデコレートする。オブジェクトミューテーションまたはリストの追加といった Python 側の効果に依存しないこと。
tf.functionは TensorFlow 演算子と最も相性が良く、NumPy と Python 呼び出しは定数に変換される。
セットアップ
発生する可能性のあるエラーの種類を示すヘルパー関数を定義します。
基礎
使い方
定義する Function(@tf.function デコレーターを適用するなどして)は、コアの TensorFlow 演算とまったく変わりません。Eager での実行や勾配の計算などを行えます。
Function をほかの Function 内で使用できます。
Function は、特に小さな演算が多数含まれるグラフでは、Eager コードよりも高速に実行されることがありますが、高価な演算がいくつか含まれるグラフ(畳み込みなど)では、速度の差はあまり見られません。
トレーシング
このセクションは、Function の内部動作や実装の詳細を説明します。将来的に変更する可能性がありますが、いつなぜトレーシングが発生するのかを理解しておけば、tf.function を効果的に使用しやすくなります。
「トレーシング」とは?
Function は TensorFlow Graph でプログラムを実行しますが、tf.Graph は、Eager TensorFlow プログラムにユーザーが記述するすべてのものを表現することはできません。たとえば、Python はポリモーフィズムをサポートしていますが、tf.Graph では、その入力に特定のデータ型と次元が必要です。またはコマンドラインの引数を読み取る、エラーを発生させる、より複雑な Python オブジェクトを扱うといったサイドタスクを実施しようとしても、どれも tf.Graph で実行することはできません。
Function はコードを 2 つの段階に分けることで、このギャップの橋渡しの役割を果たします。
「トレーシング」と呼ばれる第 1 段階において、
Functionは新しいtf.Graphを作成します。Python コードは通常通り実行しますが、すべての TensorFlow 演算(2 つのテンソルを加算するなど)は 据え置きとなります。これらはtf.Graphにとらわれるため、実行しません。第 2 段階では、最初の段階で据え置きとなったすべての演算を含む
tf.Graphが実行されます。この段階は、トレーシングの段階よりもはるかに高速に行われます。
Function は、その入力によっては必ずしも最初の段階で呼び出されたときに実行するわけではありません。この判定がどのように行われるのかについては、以下の「トレーシングの規則」をご覧ください。最初の段階を省略して 2 番目の段階のみを実行できれば、TensorFlow の高いパフォーマンスが発揮されます。
Function がトレーシングしないと判断した場合、トレーシング段階の直後に 第 2 段階が始まるため、Function を呼び出すと、tf.Graph の作成と実行が行われます。後の方で、get_concrete_function を使ってトレーシング段階のみを実行する方法を説明します。
型の異なる引数を Function に渡すと、両方の段階が実行されます。
同じ型の引数で Function を繰り返し呼び出すと、生成されるグラフはまったく同じになるため、TensorFlow はトレーシング段階を省略して前にトレーシングしたグラフを再利用することに注意してください。
すべての利用可能なトレースを確認するには、pretty_printed_concrete_signatures() を使用できます。
ここまで、tf.function が TensorFlow のグラフトレーシングロジックにキャッシュされた動的ディスパッチレイヤーを作成するのを見てきました。用語についてより具体的に説明すると、次のように言えます。
tf.Graphは、言語に依存しない、生の移植可能な TensorFlow 計算の表現です。ConcreteFunctionはtf.Graphをラップします。FunctionはConcreteFunctionのキャッシュを管理し、入力に適したものを選択します。tf.functionは Python 関数をラップし、Functionオブジェクトを返します。トレーシングは
tf.Graphを作成し、それをConcreteFunction(またはトレース)をラップします。
トレーシングの規則
Function が呼び出されると、各引数の tf.types.experimental.TraceType を使用して呼び出し引数を既存の ConcreteFunction に一致させます。一致する ConcreteFunction が見つかった場合、呼び出しはそれにディスパッチされます。一致するものが見つからない場合、新しい ConcreteFunction がトレースされます。
複数の一致が見つかった場合は、最も具体的なシグネチャが選択されます。マッチングは、たとえば C++ や Java での通常の関数呼び出しと同じように、サブタイプ化によって行われます。例えば、TensorShape([1, 2]) は TensorShape([None, None]) のサブタイプ化であるため、TensorShape([1, 2]) を使用した tf.function への呼び出しは、TensorShape([None, None]) で生成された ConcreteFunction にディスパッチできます。しかし、TensorShape([1, None]) を持つ ConcreteFunction も存在する場合は、より具体的であるため優先されます。
TraceType は、次のように入力引数から決定されます。
Tensorの場合、型はTensorのdtypeとshapeによってパラメータ化されます。階数付けされた形状は、階数付けされていない形状のサブタイプです。固定次元は未知次元のサブタイプですVariableの場合、型はTensorに似ていますが、変数の一意のリソース ID も含まれています。これは、制御の依存関係を正しく設定するために必要です。Python プリミティブ値の場合、型は値自体に対応します。たとえば、値
3のTraceTypeは、intではなくLiteralTraceType<3>です。listやtupleなどの Python の順序付きコンテナの場合、型はそれらの要素の型によってパラメータ化されます。たとえば、[1, 2]の型はListTraceType<LiteralTraceType<1>, LiteralTraceType<2>>であり、[2, 1]の型はListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>であり、異なります。dictなどの Python マッピングの場合、型も同じキーからのマッピングですが、実際の値ではなく値の型へのマッピングです。たとえば、{1: 2, 3: 4}の型はMappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>>です。ただし、順序付きコンテナとは異なり、{1: 2, 3: 4}と{3: 4, 1: 2}の型は同等です。__tf_tracing_type__メソッドを実装する Python オブジェクトの場合、型はそのメソッドが返すものですその他の Python オブジェクトの場合、型はジェネリックの
TraceTypeで、マッチング手順は以下のとおりです。まず、オブジェクトが前のトレースで使用されたオブジェクトと同じであるかをチェックします(Python の
id()またはisを使用します)。オブジェクトが変更された場合でも一致することに注意してください。そのため、Python オブジェクトをtf.functionの引数として使用する場合、イミュータブルを使用するのが最適です。次に、オブジェクトが前のトレースで使用されたオブジェクトと同じであるかをチェックします(Python の
==を使用)。
この手順では、オブジェクトへの weakref のみが維持されるため、オブジェクトが範囲内または削除されていない場合にのみ機能します。
リトレーシングの制御
リトレーシングは、Function が 2 つ以上のトレースを作成する際に発生します。これは、TensorFlow が一連の入力ごとに正しいグラフを生成する上で役立ちますが、トレーシングは高価な演算です!Function が呼び出しごとに新しいグラフをリトレーシングすると、コードの実行は tf.function を使用しない場合よりも遅くなってしまいます。
トレーシングの動作を制御するには、次のテクニックを使用できます。
固定の input_signature を tf.function に渡す
柔軟性のために未知の次元を使用する
TensorFlow は形状に基づいてテンソルを一致させるため、ワイルドカードとして None 次元を使用することで、Function が可変サイズの入力にトレースを再利用できるようになります。可変サイズの入力は、長さの異なるシーケンスがある場合や、バッチごとに画像のサイズが異なる場合に発生します(例として、Transformer と Deep Dream チュートリアルをご覧ください)。
Python リテラルの代わりにテンソルを渡す
通常、Python 引数は、num_layers=10 または training=True または nonlinearity='relu' などのように、ハイパーパラメータとグラフ構造の制御に使用されます。そのため、Python 引数が変わると、当然グラフをリトレースする必要が出てきます。
しかし、Python 引数がグラフ構造の制御に使用されていない場合もあります。こういった場合、Python の値の変化によってリトレーシングがトリガーされますが、これは不要です。この、AutoGraph が動的にアンロールするトレーニングループを例に見てみましょう。トレースが何度も行われますが、生成されたグラフはまったく同じであるため、リトレーシングは不要と言えます。
リトレーシングを強制する必要がある場合は、新しい Function を作成します。トレースは絶対に、各 Function オブジェクト間で共有されることはありません。
トレースプロトコルを使用する
可能であれば、代わりに Python 型を tf.experimental.ExtensionType に変換することをお勧めします。さらに、ExtensionType の TraceType は、それに関連付けられた tf.TypeSpec です。したがって、必要に応じて、デフォルトの tf.TypeSpec を単純にオーバーライドして、ExtensionType の Tracing Protocol を制御できます。詳細については、拡張型ガイドの ExtensionType の TypeSpec のカスタマイズセクションをご覧ください。
それ以外の場合は、Function が特定の Python 型に関していつ再トレースする必要があるかを直接制御するために、Tracing Protocol を自分で実装できます。
具象関数の取得
関数がトレースされるたびに新しい具象関数が作成されますが、get_concrete_function を使うことで、具象関数を直接取得できます。
ConcreteFunction を出力すると、入力引数(型付き)とその出力型の概要が表示されます。
また、具象関数のシグネチャを直接取得することもできます。
互換性のない型で具象トレースを使用すると、エラーが発生します。
Python 引数は、具象関数の入力シグネチャで特別に扱われていることに気づいたかもしれません。TensorFlow 2.3 より前では、Python 引数は単に具象関数のシグネチャから削除されていましたが、TensorFlow 2.3 からはシグネチャに残されたまま、トレーシング中に値セットを取るように制約されています。
グラフの取得
それぞれの具象関数は、tf.Graph を囲む呼び出し可能なラッパーです。通常、実際の tf.Graph オブジェクトを取得する必要はないにしろ、具象関数から簡単に取得することが可能です。
デバッグ
一般的に、コードのデバックは、tf.function 内で行うよりも、Eager モードで行う方が簡単です。Eager モードでは、tf.function でデコレートする前に、コードがエラーなく実行することを確認しておく必要があります。デバッグプロセスを支援する目的で、tf.config.run_functions_eagerly(True) を呼び出すと、tf.function をグローバルに無効にして、有効にし直すことができます。
tf.function 内でのみ出現する問題を追跡する場合、次のようなヒントがあります。
従来のシンプルな Python
print呼び出しは、トレーシング中にのみ実行されるため、関数が(リ)トレーシングされるときに追跡しやすくなります。tf.print呼び出しは毎回実行するため、実行中の中間値の追跡に役立ちます。tf.debugging.enable_check_numericsは、NaN と Inf がいつ作成されるかを簡単に追跡できます。pdb(Python デバッガ)は、トレーシング中に何が起きているのかを理解する上で役立ちます。(注意:pdbが示すのは、AutoGraph 変換ソースコードです。)
AutoGraph 変換
AutoGraph は、tf.function 内でデフォルトで利用できるようになっているライブラリで、Python の Eager コードのサブセットとグラフ対応の TensorFlow 演算に変換します。これには、if、for、while などの制御フローが含まれます。
tf.cond や tf.while_loop などの TensorFlow 演算は機能し続けますが、制御フローは、Python で記述された場合の方が書きやすく理解しやすいことがほとんどです。
興味があれば、AutoGraph が生成するコードを検査できます。
条件文
AutoGraph は if <condition> 文を相当する tf.cond 呼び出しに変換します。この置換は、<condition> がテンソルである場合に行われます。テンソルでない場合は、if 文は Python の条件文として実行されます。
Python 条件文はトレーシング中に実行するため、条件文のブランチが 1 つだけグラフに追加されます。AutoGraph を使用しない場合、データに依存する制御フローが存在すると、トレーシングされたこのグラフは別のブランチを取ることができません。
tf.cond は、条件文の両方のブランチをトレーシングし、実行時に動的に 1 つのブランチを選択してグラフに追加します。トレーシングには意図しない副作用がある場合があります。詳細は、AutoGraph のトレーシング効果をご覧ください。
AutoGraph 変換の if 文におけるその他の制約事項については、リファレンスドキュメントをご覧ください。
ループ
AutoGraph は、一部の for 文と while 文を相当する tf.while_loop などの TensorFlow のループ演算に変換します。変換されない場合、for または while ループは Python ループとして実行されます。
この置き換えは、次の場合に行われます。
for x in y:yがテンソルである場合、tf.while_loopに変換されます。yがtf.data.Datasetである特別なケースでは、tf.data.Dataset演算の組み合わせが生成されます。while <condition>:<condition>がテンソルである場合、tf.while_loopに変換されます。
Python ループは、トレーシング中に実行され、ループのいてレーションごとに、tf.Graph に追加の演算が追加されます。
TensorFlow ループはループの本体をトレーシングし、実行時に実行する反復回数を動的に選択します。ループ本体は、生成された tf.Graph に一度だけ出現します。
AutoGraph 変換の for 文と while 文におけるその他の制約事項については、リファレンスドキュメントをご覧ください。
Python データのループ
一般的な落とし穴は、tf.function 内で Python/NumPy データをループする際にあります。このループは、トレーシングプロセス中に実行し、ループのイテレーションごとにモデルのコピーを tf.Graph に追加してしまいます。
トレーニングループ全体を tf.function にラップしたいのであれば、データを tf.data.Dataset としてラップし、AutoGraph にトレーニングループを動的に展開させるようにするのが最も安全な方法です。
Python/NumPy データをデータセットにラップする際は、tf.data.Dataset.from_generator と tf.data.Dataset.from_tensor_slices の違いに注意してください。前者は、データを Python に維持し、tf.py_function 経由で取得するため、パフォーマンスに問題がありますが、後者は、データのコピーをグラフ内の大型の tf.constant() ノードとしてバンドル化するため、メモリに問題が現れます。
データを消費するには、TFRecordDataset や CsvDataset などを介してファイルからデータを読み取るのが最も効果的な方法です。そうすれば、Python を使わずに、TensorFlow 自体でデータの非同期読み込みとプリフェッチを管理できるようになります。詳細は、「tf.data: TensorFlow 入力パイプラインを構築する」ガイドをご覧ください。
ループでの値の累積
ループの反復ごとに値を累積していくのは一般的なパターンです。通常は、Python のリストに追加したり、Python ディクショナリにエントリを追加したりして行われますが、これらは Python の副作用であるため、動的に展開されるループでは期待どおりに動作しません。動的に展開されるループの結果を累積する場合は、tf.TensorArray を使用してください。
制限事項
TensorFlow の Function には、設計上、いくつかの制限事項があり、Python 関数を Function に変換する際には、注意が必要です。
Python の副作用の実行
Function 内での出力、リストへのアペンド、グローバル変数のミューテーションといった副作用は、2 回実行されたり、まったく実行しなかったりといったように、予測のつかない動作をすることがあります。また、入力セットで Function を初めて呼び出した場合にのみ実行し、以降では、Python コードを実行せずに、トレーシング済みの tf.Graph が再実行されてしまうこともあります。
基本的に、ロジックでは Python の副作用に依存しないようにし、トレースをデバッグするためだけに使用することをお勧めします。呼び出しごとに TensorFlow ランタイムが確実にコードを実行できるようにするには、tf.data、tf.print、tf.summary、tf.Variable.assign、tf.TensorArray などの TensorFlow API を使用するのが最善の方法です。
Function の呼び出しごとに Python コードを実行する場合は、tf.py_function が脱出口です。tf.py_function には移植性がなく、特にパフォーマンスに優れているわけでもなく、SavedModel で保存できなければ、分散型(マルチ GPU、TPU)の環境でうまく動作するわけでもありません。また、tf.py_function はグラフに組み込む必要もあるため、すべての入力/出力をテンソルにキャストしてしまいます。
Python のグローバル変数と自由変数の変更
Python のグローバル変数と自由変数の変更は、Python の副作用としてみなされるため、トレーシング中にのみ発生します。
場合によっては気づきにくい予期しない動作が発生することがあります。以下の例では、counter は変数のインクリメントを保護することを目的としています。ただし、これは Python 整数であり、TensorFlow オブジェクトではないため、その値は最初のトレース中にキャプチャされます。tf.function を使用すると、assign_add が下のグラフに無条件に記録されます。したがって、v は、tf.function が呼び出されるたびに 1 ずつ増加します。この問題は、tf.function デコレータを使用して Grpah モードの Tensorflow コードを Tensorflow 2 に移行しようとする場合、Python の副作用 (例では counter ) を使用して、実行する演算を決定すると (例では、assign_add )によく発生します。通常、ユーザーは、疑わしい数値結果を確認したり、予想よりもパフォーマンスが大幅に低下した場合に、このことに気付きます(たとえば、保護された演算に非常にコストがかかる場合)。
このような動作を回避し、期待される動作を実現するためには、tf.init_scope を使用して演算を関数グラフの外に移動します。これにより、変数のインクリメントがトレース時間中に 1 回だけ実行されるようになります。init_scope には、制御フローのクリアや勾配テープなどの他の副作用があることに注意してください。init_scope を使用すると非常に複雑になり、現実的に管理できない場合があります。
まとめると、経験則として、Function の外側で機能する整数またはリストのようなコンテナなどの Python オブジェクトのミューテーションは避けてください。代わりに、引数と TF オブジェクトを使用しましょう。たとえば、「ループでの値の累積」セクションには、リストのような演算を実装する方法の一例が示されています。
一部のケースでは、tf.Variable である場合に状態をキャプチャして操作することができます。Keras モデルの重みは、このようにして、同じ ConcreteFunction への呼び出しの繰り返しで更新されています。
Python イテレータとジェネレータの使用
ジェネレータやイテレータなどの多くの Python 機能は、Python ランタイムに依存して状態を追跡しています。一般的に、これらのコンストラクトは Eager モードでも期待どおりに動作しますが、Python の副作用の例であるため、トレーシング中にしか発生しません。
TensorFlow にリストコントラクト用の特別な tf.TensorArray があるように、イテレーション用にも特別な tf.data.Iterator があります。概要は、AutoGraph 変換をご覧ください。また、tf.data API を使って、ジェネレータのパターンを実装できます。
tf.function のすべての出力は値を返す必要がある
tf.Variableを除いて、tf.function はすべての出力を返す必要があります。戻り値を使用せずに関数からテンソルに直接アクセスしようとすると、「リーク」が発生します。
たとえば、以下の関数は、Python グローバル x を介してテンソル a を「リーク」します。
リークされた値も返される場合でもリークします。
通常、このようなリークは、Python ステートメントまたはデータ構造を使用するときに発生します。アクセスできないテンソルがリークするだけでなく、このようなステートメントは Python の副作用としてカウントされ、すべての関数呼び出しで実行されないことがあるため、間違っている可能性があります。
また、一般的に外部 Python コレクションまたはオブジェクトの変更によりローカルテンソルがリークすることもあります。
再帰的な tf.function はサポートされていない
再帰的な Function はサポートされていないので、無限ループを引き起こす可能性があります。以下に例を示します。
再帰的な Function が正しく動作しているように見えても、Python 関数は複数回トレースされ、パフォーマンスに影響を与える可能性があります。以下に例を示します。
既知の問題
Function が正しく評価していない場合、以下の既知の問題が該当する可能性があります。これらの問題は、今後修正される予定です。
Python のグローバル変数と自由変数への依存
Function は、Python 引数の新しい値で呼び出された時に新しい ConcreteFunction を作成しますが、Python クロージャ、グローバル変数、またはその Function の非ローカル変数に対しては作成しません。Function への呼び出しごとに値が変化する場合でも、Function はトレーシングされたときの値をそのまま使用してしまいます。これは、通常の Python 関数の動作とは異なります。
このため、外側の名前を閉じる代わりに引数を使用する関数プログラミングの様式をお勧めします。
グローバル値を更新する別の方法として、それを tf.Variable にし、代わりに Variable.assign メソッドを使用することができます。
Python オブジェクトへの依存
カスタム Python オブジェクトを引数として tf.function に渡すことはサポートされていますが、ある制限が伴います。
特徴量を最大限にカバーするには、オブジェクトを tf.function に渡す前に拡張型に変換することを検討してください。Python プリミティブ型と tf.nest 対応構造も使用できます。
ただし、トレーシングのルールで説明されるように、カスタム TraceType がカスタム Python クラスによって提供されない場合、tf.function はインスタンスベースの等価性を使用するように強制されてしまいます。そのため、変更された属性と同じオブジェクトを渡しても、新しいトレースは作成されません。
変更されたモデルインスタンスの評価に同じ Function を使用すると、元のモデルと同じインスタンスに基づく TraceTypeがあるために不具合が生じます。
そのため、ミュータブルオブジェクト属性に依存しない Function を書くか、そのような属性を Function に伝達するオブジェクトにトレーシングプロトコルを実装することをお勧めします。
この方法が困難な場合は、回避策として、オブジェクトを変更するたびに新しい Function がリトレーシングを行うようにする方法が挙げられます。
リトレーシングにはコストがかかるため、tf.Variable をオブジェクト属性として使用することができます。こうすることで、リトレーシングを行わずに、ミュートして(変更はしません)同様の効果を得ることができます。
tf.Variables の作成
Function は、最初の呼び出しで 1 回作成され、後続の関数呼び出しで再利用されるシングルトン tf.Variable のみをサポートします。以下のコードスニペットは、すべての関数呼び出しで新しい tf.Variable を作成します。これにより、ValueError 例外が発生します。
例:
この制限を回避するために使用される一般的なパターンは、Python None 値で開始し、値が None の場合は条件付きで tf.Variable を作成することです。
複数の Keras オプティマイザとの使用
2 つ以上の Keras オプティマイザを tf.function で使用しようとすると、「ValueError: tf.function only supports singleton tf.Variables created on the first call. 」というエラーが発生することがあります。このエラーは、オプティマイザが初めて勾配を適用する際に、内部的に tf.Variables を作成するために発生するものです。
トレーニング中にオプティマイザを変更する必要がある場合は、回避策として、オプティマイザごとに新しい Function を作成し、ConcreteFunction を直接呼び出すようにすることができます。
複数の Keras モデルとの使用
また、別のモデルインスタンスを同一の Function に渡す際に、「ValueError: tf.function only supports singleton tf.Variables created on the first call.」というエラーも発生することがあります。
このエラーは、Keras モデル(入力形状が定義されていない)と Keras レイヤーが、初めて呼び出されるときに tf.Variables を作成するために発生するものです。これらの変数をすでに呼び出された Function 内で初期化しようとしているのでしょう。このエラーを回避するには、モデルをトレーニングする前に、model.build(input_shape) を呼び出して、すべての重みを初期化するようにしてください。
参考資料
Function のエクスポートと読み込みの方法については、SavedModel ガイドをご覧ください。トレーシングの後に実行するグラフの最適化については、Grappler ガイドをご覧ください。データパイプラインの最適化方法とモデルのプロファイリングについては、Profiler ガイドをご覧ください。
TensorFlow.org で表示
Google Colab で実行
GitHub でソースを表示
ノートブックをダウンロード