Path: blob/master/site/ja/federated/tutorials/sparse_federated_learning.ipynb
25118 views
Copyright 2021 The TensorFlow Federated Authors.
federated_select
とスパースな集約によるクライアント効率の高い大規模なモデルの連合学習
このチュートリアルでは、TFF で tff.federated_select
とスパースな集約を使用して、非常に大規模なモデルをトレーニングする方法を実演します。各クライアントデバイスはモデルのごく一部のみをダウンロードおよび更新します。このチュートリアルは自己完結型ですが、ここで使用されるいくつかのテクニックの基本的な説明については、tff.federated_select
チュートリアルとカスタム連合学習アルゴリズムチュートリアルを参照してください。
このチュートリアルでは、マルチラベル分類のロジスティック回帰を検討し、bag-of-words の特徴量表現に基づいて、どの「タグ」がテキスト文字列に関連付けられているかを予測します。重要なのは、通信とクライアント側の計算コストは固定定数 (MAX_TOKENS_SELECTED_PER_CLIENT
) によって制御され、全体の語彙サイズ(実際の設定では非常に大きくなる可能性があります)に合わせてスケーリングされないということです。
各クライアントは、最大でこの数の一意のトークンのモデルの重みの行を federated_select
します。これにより、クライアントのローカルモデルのサイズと、実行されるサーバー->クライアント (federated_select
) およびクライアント->サーバー (federated_aggregate
) 通信の量が制限されます。
これを 1(各クライアントからのすべてのトークンが選択されていないことを確認)または大きな値に設定しても、このチュートリアルは正しく実行されますが、モデルの収束が影響を受ける可能性があります。
また、さまざまな型の定数をいくつか定義します。このコラボの場合、トークンは、データセットを解析した後の特定の単語の整数識別子です。
問題の設定: データセットとモデル
このチュートリアルでは、簡単に実験できるように小規模なトイデータセットを作成します。ただし、データセットの形式は Federated StackOverflow と互換性があり、前処理とモデルアーキテクチャは、適応型連合最適化 の StackOverflow タグ予測問題から採用されています。
データセットの解析と前処理
小規模なトイデータセット
12 語のグローバル語彙と 3 つのクライアントを使用して、小規模なトイデータセットを作成します。この小規模なサンプルは、エッジケースのテスト(たとえば、MAX_TOKENS_SELECTED_PER_CLIENT = 6
未満の個別のトークンを持つ 2 つのクライアントと、それ以上のトークンを持つ 1 つのクライアント)およびコードの開発に役立ちます。
ただし、このアプローチの実際のユースケースでは、数千万以上のグローバル語彙であり、各クライアントに数千の異なるトークンが表示される可能性があります。データの形式が同じであるため、より現実的なテストベッドの問題への拡張(tff.simulation.datasets.stackoverflow.load_data()
データセット)は簡単です。
まず、単語とタグの語彙を定義します。
ここで、小規模なローカルデータセットをもつ 3 つのクライアントを作成します。このチュートリアルを colab で実行している場合は、以下で開発した関数の出力を解釈/確認するために、「タブ内のセルのミラーリング」機能を使用してこのセルとその出力を固定すると便利な場合があります。
入力フィーチャ(トークン/単語)とラベル(ポストタグ)の生の数の定数を定義します。OOV トークン/タグを追加するため、実際の入出力スペースは NUM_OOV_BUCKETS = 1
大きくなります。
データセットのバッチバージョンと個々のバッチを作成します。これは、コードのテストに役立ちます。
スパース入力でモデルを定義する
タグごとに単純な独立ロジスティック回帰モデルを使用します。
まず、予測を行って、それが機能することを確認しましょう。
そして、いくつかの簡単な集中トレーニングを実行します。
連合計算のビルディングブロック
連合平均化アルゴリズムの単純なバージョンを実装しますが、重要な違いは、各デバイスはモデルの関連するサブセットのみをダウンロードし、そのサブセットへの更新のみを提供するということです。
MAX_TOKENS_SELECTED_PER_CLIENT
の省略形として M
を使用します。上位レベルでは、1 ラウンドのトレーニングには次の手順が含まれます。
参加している各クライアントは、ローカルデータセットをスキャンし、入力文字列を解析して、正しいトークン(int インデックス)にマッピングします。これには、グローバル(大規模)ディクショナリへのアクセスが必要です(これは、機能ハッシュ手法を使用して回避できる可能性があります)。次に、各トークンが発生する回数をスパースにカウントします。
U
の一意のトークンがデバイスで発生する場合、トレーニングするnum_actual_tokens = min(U, M)
の最も頻繁なトークンを選択します。クライアントは
federated_select
を使用して、サーバーからnum_actual_tokens
で選択されたトークンのモデル係数を取得します。各モデルスライスは形状(TAG_VOCAB_SIZE, )
のテンソルであるため、クライアントに送信されるデータの合計は最大でサイズTAG_VOCAB_SIZE * M
になります(以下の注意事項を参照)。クライアントは、マッピング
global_token -> local_token
を作成します。ローカルトークン(int index)は、選択されたトークンのリスト内のグローバルトークンのインデックスです。クライアントは、範囲
[0, num_actual_tokens)
から最大M
トークンの係数のみを持つグローバルモデルの「小規模な」バージョンを使用します。global -> local
マッピングは、選択したモデルスライスからこのモデルの密なパラメータを初期化するために使用されます。クライアントは、
global -> local
マッピングで前処理されたデータに対して SGD を使用してローカルモデルをトレーニングします。クライアントは、
local -> global
マッピングを使用して行にインデックスを付けることにより、ローカルモデルのパラメータをIndexedSlices
更新に変換します。サーバーは、スパースな集約を使用してこれらの更新を集約します。サーバーは、上記の集約の(密な)結果を取得し、それを参加しているクライアントの数で除算し、結果の平均更新をグローバルモデルに適用します。
このセクションでは、これらのステップの構成要素を構築します。これらの構成要素は、1 つ のトレーニングラウンドの完全なロジックをキャプチャする最終的な federated_computation
に結合されます。
注意: 上記に説明されていない 1 つの技術的な詳細があります。
federated_select
とローカルモデルの構築では、静的に既知の形状が必要であるため、動的なクライアントごとのnum_actual_tokens
サイズを使用できません。代わりに、静的な値M
を使用し、必要に応じてパディングを追加します。これは、アルゴリズムのセマンティクスには影響しません。
クライアントトークンをカウントし、federated_select
にスライスするモデルを決定する
各デバイスは、モデルのどの「スライス」がローカルトレーニングデータセットに関連しているかを判断する必要があります。ここでは、クライアントトレーニングデータセットの各トークンを含むサンプルの数を(スパースに)カウントします。
デバイスで最も頻繁に発生する MAX_TOKENS_SELECTED_PER_CLIENT
トークンに対応するモデルパラメータを選択します。デバイスで発生するトークンの数がこれより少ない場合は、リストを埋めて federated_select
を使用できるようにします。
(発生確率に基づいて)トークンをランダムに選択するなど、他の戦略の方がおそらく優れていることに注意してください。これにより、(クライアントがデータを持っている)モデルのすべてのスライスが更新されます。
グローバルトークンをローカルトークンにマップする
上記の選択により、オンデバイスモデルに使用する [0, actual_num_tokens)
の範囲の密なトークンのセットが得られます。ただし、読み取ったデータセットには、はるかに大きなグローバル語彙範囲 [0, WORD_VOCAB_SIZE)
のトークンが含まれています。
したがって、グローバルトークンを対応するローカルトークンにマップする必要があります。ローカルトークン ID は、前の手順で計算された selected_tokens
テンソルへのインデックスによって与えられます。
各クライアントでローカル(サブ)モデルをトレーニングする
注意: federated_select
は、選択したスライスを、選択キーと同じ順序で tf.data.Dataset
として返します。したがって、最初に、そのようなデータセットを取得し、それをクライアントモデルのモデルの重みとして使用できる単一の密なテンソルに変換する効用関数を定義します。
各クライアントで実行される単純なローカルトレーニングループを定義するために必要なすべてのコンポーネントが揃いました。
IndexedSlices を集約する
tff.federated_aggregate
を使用して、IndexedSlices
のスパースな連合集約を作成します。この単純な実装には、density_shape
が事前に静的に認識されているという制約があります。また、この集約は、セミスパース(クライアント->サーバー通信がスパース)ですが、サーバーは accumulate
およびmerge
で集約の密な表現を維持し、この密な表現を出力していることにも注意してください。
テストとして最小限の federated_computation
を作成します
federated_computation
に全てをまとめる
TFF を使用して、コンポーネントを tff.federated_computation
にまとめます。
連合平均化に基づく基本的なサーバートレーニング関数を使用し、サーバー学習率 1.0 で更新を適用します。クライアント提供のモデルを単純に平均化するのではなく、モデルに更新(デルタ)を適用することが重要です。そうでなければ、モデルの特定のスライスが特定のラウンドでどのクライアントによってもトレーニングされていない場合、その係数はゼロになる可能性があります。
さらにいくつかの tff.tf_computation
の要素が必要です。
以上ですべての要素をまとめる準備が整いました。
モデルをトレーニングしましょう
トレーニング関数ができたので、試してみましょう。