Path: blob/master/site/ja/probability/examples/Distributed_Inference_with_JAX.ipynb
25118 views
Copyright 2020 The TensorFlow Probability Authors.
Licensed under the Apache License, Version 2.0 (the "License");
JAX の TensorFlow Probability (TFP) に、分散数値計算用のツールが追加されました。多数のアクセラレータに拡張するために、ツールは「単一プログラム複数データ」パラダイム(SPMD)を使用してコードを記述することを中心に構築されています。
このノートブックでは、「SPMD で考える」方法を説明し、TPU ポッドや GPU のクラスタなどの構成にスケーリングするための新しい TFP 抽象化を紹介します。このコードを自分で実行する場合は、必ず TPU ランタイムを選択してください。
まず、最新バージョンの TFP、JAX、TF をインストールします。
ERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.4.0 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement grpcio~=1.32.0, but you'll have grpcio 1.34.1 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement h5py~=2.10.0, but you'll have h5py 3.1.0 which is incompatible.
ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible.
ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.
ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.
ERROR: tf-nightly-cpu 2.6.0.dev20210401 has requirement numpy~=1.19.2, but you'll have numpy 1.20.2 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement gast==0.3.3, but you'll have gast 0.4.0 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement grpcio~=1.32.0, but you'll have grpcio 1.34.1 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement h5py~=2.10.0, but you'll have h5py 3.1.0 which is incompatible.
ERROR: tensorflow 2.4.1 has requirement numpy~=1.19.2, but you'll have numpy 1.20.2 which is incompatible.
ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.25.1 which is incompatible.
ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.
ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.
いくつかの JAX ユーティリティと一般的なライブラリをインポートします。
また、いくつかの便利な TFP エイリアスを設定します。新しい抽象化は現在、tfp.experimental.distribute
と tfp.experimental.mcmc
で提供されています。
ノートブックを TPU に接続するには、JAX の次のヘルパーを使用します。接続されていることを確認するために、デバイスの数を出力します。これは 8 である必要があります。
jax.pmap
の簡単な紹介
TPU に接続すると、8 台のデバイスにアクセスできるようになります。ただし、JAXコードを eager に実行すると、JAX はデフォルトで 1 台だけで計算を実行します。
多くのデバイス間で計算を実行する最も簡単な方法は、関数をマップし、各デバイスにマップの 1 つのインデックスを実行させることです。JAX は、関数を複数のデバイスにマップする関数に変換する jax.pmap
(「並列マップ」)変換を提供します。
次の例では、サイズ 8 の配列を作成し(使用可能なデバイスの数に一致させるため)、それに 5 を追加する関数をマップします。
ShardedDeviceArray
型が返され、出力配列がデバイス間で物理的に分割されていることを示していることに注目してください。
jax.pmap
は意味的にマップのように機能しますが、その動作を変更するいくつかの重要なオプションがあります。デフォルトでは、pmap
は関数へのすべての入力がマップされていると想定していますが、in_axes
引数を使用してこの動作を変更できます。
同様に、pmap
の out_axes
引数は、すべてのデバイスで値を返すかどうかを決定します。out_axes
を None
に設定すると、最初のデバイスの値が自動的に返されます。値がすべてのデバイスで同じであると確信できる場合にのみ使用してください。
実行したいことがマップされた純粋関数として簡単に表現できない場合はどうなるでしょうか。たとえば、マッピングしている軸全体で合計を計算したい場合はどうなるでしょうか。JAX は、デバイス間で通信する「集合」機能を提供し、より興味深く複雑な分散プログラムを作成できるようにします。それらがどのように機能するかを説明するために、SPMD を紹介します。
SPMD とは
シングルプログラムマルチデータ(SPMD)は、単一のプログラム(同一コード)がデバイス間で同時に実行される並行プログラミングモデルですが、実行中の各プログラムへの入力は異なる場合があります。
プログラムが入力の単純な関数である場合(x + 5
など)、SPMD でプログラムを実行すると、前に jax.pmap
で行ったように、プログラムをさまざまなデータにマッピングするだけです。 ただし、関数を単に「マップ」するだけでなく、JAX は、デバイス間で通信する関数である「集合」を提供します。
たとえば、すべてのデバイスの数量の合計を取得する場合、 まず pmap
でマッピングする軸に名前を割り当てる必要があります。次に、lax.psum
( 「parallel sum」)関数を使用してデバイス間で合計を実行し、合計する名前付き軸を確実に識別します。
psum
集合は、各デバイスの x
の値を集約し、その値をマップ全体で同期します。つまり、out
は各デバイスで 28.
です。単純な「マップ」ではなく、SPMD プログラムを実行しています。SPMD プログラムでは、集合体を使用する方法は限られていますが、各デバイスの計算が他のデバイスの同じ計算と相互作用できるようになります。このシナリオでは、psum
が値を同期するため、out_axes = None
を使用できます。
SPMD を使用すると、任意の TPU 構成のすべてのデバイスで同時に実行される 1 つのプログラムを作成できます。8 つの TPU コアで機械学習を行うために使用するコードを、数百から数千のコアを持つ TPU ポッドで使用できます。jax.pmap
と SPMD の詳細なチュートリアルについては、JAX 101 チュートリアルを参照してください。
大規模な MCMC
このノートブックでは、ベイズ推定にマルコフ連鎖モンテカルロ(MCMC)法を使用することに焦点を当てています。MCMC に多くのデバイスを利用する方法はいくつかありますが、このノートブックでは、次の 2 つに焦点を当てます。
異なるデバイスで独立したマルコフ連鎖を実行します。このケースは非常に単純で、バニラ TFP で行うことができます。
デバイス間でデータセットをシャーディングします。このケースはもう少し複雑で、最近追加された TFP 機能が必要です。
独立した連鎖
MCMC を使用して問題についてベイズ推定を行い、複数のデバイス間で複数のチェーンを並列に実行したいとします(たとえば、各デバイスで 2 つ)。これは、デバイス間で「マッピング」できるプログラム、つまり集合体を必要としないプログラムです。各プログラムが(同じマルコフ連鎖を実行するのではなく)異なるマルコフ連鎖を実行することを確認するために、各デバイスに異なる値のランダムシードを渡します。
2 次元ガウス分布からサンプリングするトイプロブレムで試してみましょう。TFP の既存の MCMC 機能をそのまま使用できます。一般に、マップされた関数内にほとんどのロジックを配置して、すべてのデバイスで実行されているものと最初のデバイスだけで実行されているものをより明確に区別します。
run
関数は、それ自体でステートレスランダムシードを取り込みます(ステートレスランダム性がどのように機能するかを確認するには、 JAX で TFP を使用する ノートブック および JAX 101 チュートリアルを参照してください。異なるシードに run
をマッピングすると、複数の独立したマルコフ連鎖が実行されます。
各デバイスに対応する追加の軸があることに注意してください。次元を並べ替えて平坦化し、16 の連鎖の軸を取得できます。
多くのデバイスで独立したチェーンを実行するのは、 tfp.mcmc
を使用する関数で pmap
を実行するだけで、各デバイスにランダムシードとして異なる値を渡すことができます。
データのシャーディング
MCMC では、多くの場合、ターゲットのディストリビューションはデータセットの条件付けによって取得された事後ディストリビューションであり、正規化されていない対数密度の計算には、観測された各データの尤度の合計が含まれます。
データセットが非常に大きい場合、1 台のデバイスで 1 つのチェーンを実行する場合にでも非常にコストがかかる可能性がありますが、複数のデバイスにアクセスできる場合は、データセットをデバイス間で分割して、利用可能なコンピューティングをより有効に活用できます。
シャーディングされたデータセットを使用して MCMC を実行する場合は、各デバイスで計算する非正規化対数密度が合計、つまりすべてのデータの密度を表すようにする必要があります。そうしないと、各デバイスは独自の誤ったターゲットディストリビューションで MCMC を実行します。そのため、TFP には、「シャーディングされた」対数確率の計算とそれらを使用した MCMC の実行を可能にする新しいツール(tfp.experimental.distribute
と tfp.experimental.mcmc
)があります。
シャーディングされたディストリビューション
TFP がシャーディングされた対数確率の計算に提供するコア抽象化は、Sharded
メタディストリビューションです。これは、入力としてディストリビューションを受け取り、SPMD コンテキストで実行されると特定のプロパティを持つ新しいディストリビューションを返します。Sharded
は tfp.experimental.distribute
にあります。
直感的には、Sharded
ディストリビューションは、デバイス間で「分割」された確率変数のセットに対応します。各デバイスで、それらは異なるサンプルを生成し、個別に異なる対数密度を持つことができます。あるいは、Sharded
ディストリビューションは、グラフィカルモデル用語の「プレート」に対応します。プレートサイズはデバイス数です。
Sharded
ディストリビューションのサンプリング
各デバイスで同じシードを使用してpmap
を実行しているプログラムで、Normal
ディストリビューションからサンプリングすると、各デバイスで同じサンプルが取得されます。次の関数は、デバイス間で同期される単一の確率変数をサンプリングするものと考えることができます。
tfd.Normal(0., 1.)
を tfed.Sharded
でラップすると、論理的に 8 つの異なる確率変数(各デバイスに 1 つ)が存在するため、 同じシードを渡しても、それぞれに異なるサンプルを生成します。
単一のデバイスでのこの分布の同等の表現は、8 つの独立した正規サンプルです。サンプルの値は異なりますが (tfed.Sharded
は疑似乱数の生成をわずかに異なる方法で行います。どちらも同じディストリビューションを表します。
Sharded
ディストリビューションの対数密度を取得する
SPMD コンテキストで正規分布からサンプルの対数密度を計算するとどうなるか見てみましょう。
各サンプルは各デバイスで同じであるため、各デバイスで同じ密度を計算します。直感的には、ここでは、ディストリビューションは単一の正規分布変数にのみあります。
Sharded
ディストリビューションでは、8 つの確率変数にわたるディストリビューションがあるため、サンプルのlog_prob
を計算するときに、デバイス間で個々の対数密度をそれぞれ合計します。(この合計 log_prob 値は、上記で計算された単一の log_prob よりも大きいことに気付くかもしれません。)
同等の「シャーディングされていない」ディストリビューションは、同じ対数密度を生成します。
Sharded
ディストリビューションは、各デバイスの sample
とは異なる値を生成しますが、各デバイスの log_prob
に対して同じ値を取得します。何が起こっているのでしょうか? Sharded
ディストリビューションは、psum
を内部的に実行して、log_prob
の値がデバイス間で同期していることを確認します。 なぜこの動作が必要なのでしょうか?各デバイスで 同じ MCMC チェーンを実行している場合、計算の確率変数がデバイス間でシャーディングされている場合でも、target_log_prob
を各デバイスで同じにする必要があります。
さらに、Sharded
ディストリビューションは、デバイス間の勾配が正しいことを保証し、遷移関数の一部として対数密度関数の勾配をとる HMC のようなアルゴリズムが適切なサンプルを生成することを保証します。
シャードされた JointDistribution
JointDistribution
(JD)を使用して、複数の Sharded
確率変数を持つモデルを作成できます。残念ながら、Sharded
ディストリビューションは、バニラ tfd.JointDistribution
で安全に使用できませんが、tfp.experimental.distribute
は、Sharded
ディストリビューションのように動作する「パッチが適用された」JD をエクスポートします。
これらのシャーディングされた JD は、コンポーネントとして Sharded
とバニラ TFP 分布の両方を持つことができます。シャーディングされていないディストリビューションの場合、各デバイスで同じサンプルを取得し、シャーディングされた分布の場合、異なるサンプルを取得します。各デバイスの log_prob
も同期されます。
MCMC での Sharded
ディストリビューション
MCMC のコンテキストでは、Sharded
ディストリビューションについてどのように考えればよいでしょうか?JointDistribution
として表現できる生成モデルがある場合、そのモデルの軸を選択して「シャーディング」することができます。通常、モデル内の 1 つの確率変数は観測データに対応し、デバイス間でシャーディングする大きなデータセットがある場合は、データポイントに関連付けられている変数もシャーディングする必要があります。また、シャーディングしている観測値と 1 対 1 の「ローカル」確率変数がある可能性があるため、これらの確率変数をさらにシャーディングする必要があります。
このセクションでは、TFP MCMC での Sharded
ディストリビューションの使用例について説明します。 distribute
ライブラリのいくつかのユースケースを示すために、単純なベイズロジスティック回帰の例と行列因数分解の例を見ていきます。
例: MNIST に対するベイズロジスティック回帰
大規模なデータセットに対してベイズロジスティック回帰を実行します。モデルには、回帰の重みよりも前の があり、 合計同時対数密度を取得するためのすべてのデータ で合計される尤度 があります。データをシャーディングする場合、モデルで観測された確率変数 と をシャーディングします。
MNIST の分類には次のベイズロジスティック回帰モデルを使用します。
TensorFlow データセットを使用して MNIST を読み込みましょう。
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
60000 のトレーニング画像がありますが、利用可能な 8 つのコアを利用して、8 つの方法で分割します。次の便利な shard
ユーティリティ関数を使用します。
先に進む前に、TPU の精度と HMC への影響について簡単に説明しましょう。TPUは、速度のために低い bfloat16
精度を使用して行列の乗算を実行します。bfloat16
行列の乗算は、頻繁に多くの深層学習アプリケーションに対して十分ですが、HMC で使用すると、精度が低いと軌道が発散し、拒否が発生する可能性があることが経験的に明らかになっています。追加の計算をいくらか犠牲にして、より高精度の行列乗算を使用できます。
matmul の精度を上げるには、"tensorfloat32"
の精度で jax.default_matmul_precision
デコレータを使用します(さらに精度を上げるには、"float32"
の精度を使用できます。)
次に、run
関数を定義します。この関数は、ランダムシード(各デバイスで同じになります)と MNIST のシャードを取り込みます。この関数は前述のモデルを実装し、TFP のバニラ MCMC 機能を使用して単一のチェーンを実行します。run
に jax.default_matmul_precision
デコレータを使用して、行列の乗算がより高い精度で実行されるようにします。ただし、以下の特定の例では、jnp.dot(images, w, precision=lax.Precision.HIGH)
を使用することもできます。
jax.pmap
には JIT コンパイルが含まれていますが、コンパイルされた関数は最初の呼び出し後にキャッシュされます。 run
を呼び出し、出力を無視してコンパイルをキャッシュします。
run
を再度呼び出して、実際の実行にかかる時間を確認します。
200,000 のリープフロッグステップを実行しています。各ステップは、データセット全体の勾配を計算します。計算を 8 コアに分割すると、約 95 秒で 200,000 エポックのトレーニングに相当する計算が可能になります。これは 1 秒あたり約 2,100 エポックです。
各サンプルの対数密度と各サンプルの精度をプロットしてみましょう。
サンプルをアンサンブルすると、ベイズモデル平均化を計算してパフォーマンスを向上させることができます。
ベイズモデル平均化により、精度がほぼ 1 %向上します。
例: MovieLens 推奨システム
次に、ユーザーによる映画の評価が含まれた、MovieLens 推奨データセットを使用して推論を試してみます。具体的には、MovieLens を watch matrix として表すことができます。ここで、 はユーザー数、 は映画数です。 を期待します。 のエントリは、ユーザー が映画 を視聴したかどうかを示すブール値です。MovieLens はユーザー評価を提供しますが、問題を単純化するためにそれらを無視していることに注意してください。
まず、データセットを読み込みます。100 万の評価があるバージョンを使用します。
Downloading and preparing dataset movielens/1m-ratings/0.1.0 (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0...
Shuffling and writing examples to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0.incompleteYKA3TG/movielens-train.tfrecord
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0. Subsequent calls will reuse this data.
データセットの前処理を行って、視聴行列 を取得します。
単純な確率的行列因数分解モデルを使用して、 の生成モデルを定義します。潜在的な ユーザー行列 と潜在的な を想定します。これらを乗算すると、視聴行列 のベルヌーイのロジットが生成されます。また、ユーザーと映画のバイアスベクトル、 と も含まれます。
これはかなり大きな行列です。6040 のユーザーと 3706 の映画は、2,200 万を超えるエントリを含む行列につながります。このモデルのシャーディングにどのようにアプローチすればよいでしょうか。(つまり、映画よりもユーザーが多い)と仮定すると、ユーザー軸全体で視聴行列をシャーディングするのが理にかなっているため、各デバイスにはサブセットに対応する視聴行列のチャンクがあります。ただし、前の例とは異なり、 行列もシャーディングする必要があります。これは、ユーザーごとに埋め込みがあるため、各デバイスが のシャードと をシャーディングするためです。 一方、 はシャーディングされず、デバイス間で同期されます。
run
を作成する前に、ローカル確率変数 をシャーディングする際の追加の課題について簡単に説明します。HMC を実行している場合、バニラ tfp.mcmc.HamiltonianMonteCarlo
カーネルは、チェーンの状態の各要素の運動量のサンプリングを実行します。以前は、シャーディングされていない確率変数のみがその状態の一部であり、運動量は各デバイスで同じでした。 がシャーディングされたら、 で同じ運動量をサンプリングしながら、各デバイスで で異なる運動量をサンプリングする必要があります。そのためには、 tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo
と Sharded
運動量ディストリビューションを使用できます。 並列計算をファーストクラスにし続けると、これを単純化することができます。つまり、HMC カーネルにシャードネスインジケーターを含めます。
コンパイルされた run
をキャッシュするために、もう 1 回実行します。
次に、コンパイルのオーバーヘッドなしで再度実行します。
約 3 分で約 150,000 のリープフロッグステップを完了しました。つまり、1 秒あたり約 83 のリープフロッグステップです。サンプルの受け入れ率と対数密度をプロットします。
マルコフ連鎖のサンプルがいくつかあるので、それらを使用していくつかの予測を行います。まず、各コンポーネントを抽出します。user_embeddings
と user_bias
はデバイス間で分割されるため、ShardedArray
を連結してすべてを取得する必要があることに注意してください。一方、movie_embeddings
と movie_bias
はすべてのデバイスで同じであるため、最初のシャードから値を選択するだけです。 通常の numpy
を使用して、TPU から CPU に値をコピーします。
これらのサンプルでキャプチャされた不確実性を利用する単純な推薦システムを構築してみましょう。まず、視聴確率に応じて映画をランク付けする関数を作成します。
これで、すべてのサンプルをループし、各サンプルについて、ユーザーがまだ視聴していない上位の映画を選択する関数を作成できます。 次に、サンプル全体で推奨されるすべての映画の数を確認できます。
映画の視聴数が最も多いユーザーと最も少ないユーザーを比較します。
user_most
が視聴する可能性が高い映画の種類に関する情報が多いので、システムは user_least
よりも user_most
についてより確実であるはずです。
user_least
の推奨事項には、視聴の好みにおける追加的な不確実性を反映され、より多くの差異があることがわかります。
また、推薦される映画のジャンルも見てみます。
user_most
は多くの映画を視聴しているので、ミステリーや犯罪などのよりニッチなジャンルが推奨されていますが、user_least
は多くの映画を視聴していないので、コメディやアクションの主流の映画が推奨されています。