Path: blob/master/site/ja/federated/tutorials/custom_aggregators.ipynb
25118 views
Copyright 2021 The TensorFlow Federated Authors.
カスタム集約を実装する
このチュートリアルでは、tff.aggregators
モジュールのデザイン原理とクライアントからサーバーへの値のカスタム集約を実装するためのベストプラクティスについて説明します。
前提条件: このチュートリアルでは、配置(tff.SERVER
、tff.CLIENTS
)、TFF による計算の表現方法(tff.tf_computation
、tff.federated_computation
)、および型シグネチャといった Federated Core の基本概念に精通していることを前提としています。
デザインの概要
TFF では、「集約」は、tff.SERVER
で同じ型の集約値を生成するための tff.CLIENTS
における値セットの移動を指します。つまり、各クライアント値を利用できる必要はないということです。たとえば連合学習では、クライアントモデルの更新が平均化されて、サーバー上のグローバルモデルに適用される集約モデルの更新が取得されます。
TFF には、この目標を達成する tff.federated_sum
などの演算子のほかに、TFF には、集約計算の型シグネチャを形式化するため、単純な和よりも複雑な形式に一般化できる tff.templates.AggregationProcess
(ステートフルプロセス)が備わっています。
tff.aggregators
モジュールの主要コンポーネントは、AggregationProcess
を作成するファクトリです。これは、次の 2 つの面で、一般に有用で交換可能な TFF のビルディングブロックとなるように設計されています。
パラメータ化計算。 集約は、
tff.aggregators
と連携するように設計されたほかの TFF モデルに使用し、必要な集約をパラメータ化する独立したビルディングブロックです。
例:
集約の合成。 集約ビルディングブロックは、他の集約ビルディングブロックと合成してより複雑な合成集約を作成することができます。
例:
このチュートリアルの残りの部分では、これらの 2 つの目標をどのように達成するかを説明します。
集約プロセス
まず、tff.templates.AggregationProcess
を要約して、作成のためのファクトリパターンに従います。
tff.templates.AggregationProcess
は、集約向けに指定された型シグネチャを持つ tff.templates.MeasuredProcess
です。具体的には、initialize
と next
関数に、以下の型シグネチャがあります。
( -> state_type@SERVER)
(<state_type@SERVER, {value_type}@CLIENTS, *> -> <state_type@SERVER, value_type@SERVER, measurements_type@SERVER>)
状態(state_type
型)は、サーバーに配置する必要があります。next
関数は状態を入力引数として取り、状態と値が集約される引数(value_type
型)をクライアント側に配置します。*
はオプションの他の入力引数です。たとえば、重み付き平均の重みが該当します。これは、更新された状態オブジェクト、サーバーに配置された同じ型の集約値、およびいくつかの測定値を返します。
next
関数の実行間で渡される状態と、next
関数の特定の実行に応じて情報をレポートすることを目的にレポートされた測定値は空である場合があることに注意してください。いずれにせよ、これらは TFF の他の部分が従うことのできる明確なコントラクトを持つように明示的に指定されている必要があります。
tff.learning
でのモジュール更新といった他の TFF モジュールでは、tff.templates.AggregationProcess
を使用して値の集約方法をパラメータ化することが期待されています。ただし、実際に集約されたのがどの値であり、その型シグネチャが何であるかは、トレーニングされているモデルの詳細とそれを実行するために使用される学習アルゴリズムによって異なります。
集約を計算の他の側面から独立させるために、ファクトリパターンを使用します。集約されるオブジェクトの関連する型シグネチャが利用可能になったら、ファクトリの create
メソッドを呼び出して、適切な tff.templates.AggregationProcess
を作成します。したがって、集約プロセスを直接取り扱うのは、この作成を担当するライブラリ作成者のみということになります。
集約プロセスファクトリ
重みなしと重み付きの集約には、2 つの抽象ベースのファクトリクラスがあります。その create
メソッドは集約される値の型シグネチャを取り、その値の集約に使用する tff.templates.AggregationProcess
を返します。
tff.aggregators.UnweightedAggregationFactory
が作成するプロセスは、(1)サーバーでの状態と(2)指定した型 value_type
の値の 2 つの入力引数を取ります。
実装例は tff.aggregators.SumFactory
です。
tff.aggregators.WeightedAggregationFactory
が作成するプロセスは、(1)サーバーでの状態、(2)指定した型 value_type
の値、および(3)create
メソッドを呼び出したときにファクトリのユーザーが指定した型の重み weight_type
の 3 つの入力引数を取ります。
実装例は、重み付き平均を計算する tff.aggregators.MeanFactory
です。
ファクトリパターンは、上述の最初の目標の達成方法で、集計は独立したビルディングブロックです。たとえば、トレーニング対象のモデル変数を変更しても、複合集計は必ずしも変更する必要がありません。それを表現するファクトリは、tff.learning.algorithms.build_weighted_fed_avg
などのメソッドで使用される際に、別の型シグネチャで呼び出されることになります。
構成
一般的な集約プロセスは、(a)クライアントでの値の前処理、(b)クライアントからサーバーへの値の移動、および(c)サーバーでの集約値の後処理をカプセル化できることを思い出してください。上述の 2 つ目の目標である集計の複合は、集約ファクトリの実装を(b)が別の集約ファクトリにデリゲートできるように構成することで、tff.aggregators
モジュール内で実現されます。
この実装は、必要なロジックすべてを 1 つのファクトリクラスに実装する代わりに、デフォルトで集約に関連する 1 つの側面に焦点を当てています。必要であれば、このパターンによって、一度に 1 つずつビルディングブロックを入れ替えることが可能です。
例は、重み付きの tff.aggregators.MeanFactory
です。この実装は、クライアントで提供された値と重みを乗算し、重み付きの値を加算し、その和をサーバーの重みの和で除算します。tff.federated_sum
演算子を直接使用して合計を実装する代わりに、合計は、tff.aggregators.SumFactory
の 2 つのインスタンスにデリゲートされます。
このような構造によって、2 つのデフォルトの合計を別のファクトリに置き換えることが可能となり、したがって加算が異なります。たとえば、tff.aggregators.SecureSumFactory
、または tff.aggregators.UnweightedAggregationFactory
カスタム実装があります。逆に、平均化する前に値をクリッピングする場合は、tff.aggregators.MeanFactory
自体を、tff.aggregators.clipping_factory
などの別のファクトリの内部集約にすることができます。
tff.aggregators
モジュールの既存のファクトリを使用した合成メカニズムの推奨される使用方法については、前の「推奨される集約を学習向けにチューニングする」チュートリアルをご覧ください。
例によるベストプラクティス
タスクの単純な例を実装して、tff.aggregators
の概念を詳しく説明し、それを徐々に一般化していくことにします。もう一つの学習方法は、既存のファクトリの実装を確認することです。
value
を加算する代わりに、タスク例では、value * 2.0
を加算してから、その和を 2.0
で除算します。したがって、数学的に見れば、この集約結果は value
を直接加算したものと同じになります。この方法は、(1)クライアントでのスケーリング(2)クライアント間での加算(3)サーバーでのスケーリング解除の 3 部構成と考えることができます。
注意: このタスクは、必ずしも実用的とは言えませんが、いずれにしても、根底にある概念を説明する上で役立ちます。
ロジックは、上記で説明したデザインに従って、tff.aggregators.UnweightedAggregationFactory
のサブクラスとして実装されます。これにより、集約する value_type
が与えられると、適切な tff.templates.AggregationProcess
が作成されます。
最小限の実装
タスク例の場合、必要な計算は常に同じであるため、状態を使用する必要はありません。したがって、状態は空であり、tff.federated_value((), tff.SERVER)
として表現されます。現時点では、測定値についても同様です。
したがって、タスクの最小限の実装は、以下のようになります。
すべてが期待どおりに動作するかは、以下のコードで確認できます。
ステートフルネスと測定値
TFF では、反復的に実行されることが期待されており、イテレーションごとに変化する計算を表現するために、ステートフルネスが幅広く使用されています。たとえば、学習計算の状態には、学習されているモデルの重みが含まれます。
集約の計算で状態をどのように使用するかを説明するために、タスク例に変更を加えることにします。value
を 2.0
で乗算する代わりに、それをイテレーションのインデックス(集約が実行された回数)で乗算します。
これを行うには、イテレーションのインデックスを追跡する方法が必要です。これは、状態の概念を通じて実現することができます。initialize_fn
で、空の状態を作成する代わりに、状態がスカラーのゼロになるように初期化します。すると、状態を、(1)1.0
で増分、(2)value
の乗算に使用、(3)新しい更新済みの状態として返す、という 3 段階で、next_fn
で使用することができます。
これが完了したら、「それでも、上記とまったく同じコードを使って、すべての作業が期待どおりであるかを確認できます。本当に何かが変わったことをどうすれば知ることができるのでしょうか。」という疑問が湧くことでしょう。
良い質問です!ここで生きてくるのが、測定値の概念です。一般に、測定値は、next
関数の 1 回の実行に関連するすべての値をレポートするため、監視に使用することが可能です。この場合は、前の例の summed_value
の場合があります。つまり、「スケーリング解除」ステップの前の値であり、これはイテレーションのインデックスに依存していなければなりません。繰り返しになりますが、これは必ずしも実用的ではなく、関連するメカニズムを説明しているだけです。
したがって、タスクのステートフルな答えは以下のようになります。
next_fn
に入力として渡される state
は、サーバーに配置されていることに注意してください。これをクライアントで使用するにはまず、それを伝達する必要があります。これには、tff.federated_broadcast
演算子を使用します。
すべての作業が期待どおりであることを確認するには、レポートされた measurements
を確認することができます。これは、同じ client_data
を使って実行された場合であっても、実行ラウンドごとに異なります。
構造化型
連合学習でトレーニングされたモデルの重みは通常、単一のテンソルではなく、テンソルのコレクションで表現されます。TFF では、これは tff.StructType
として表現され、一般に有用な集約ファクトリであり、構造化型を受け入れられる必要があります。
ただし、上記の例では、tff.TensorType
オブジェクトしか操作していません。以前のファクトリを使用して、tff.StructType([(tf.float32, (2,)), (tf.float32, (3,))])
で集約プロセスを作成しようとすると、TensorFlow は tf.Tensor
と list
を乗算しようとするため、奇妙なエラーが発生してしまいます。
問題は、テンソルの構造を定数で乗算する代わりに、構造内の各テンソルを定数で乗算しなければならないということです。この問題は通常、作成された tff.tf_computation
の代わりに tf.nest
モジュールを使用することで解決します。
したがって、構造化型と互換性のある前のバージョンの ExampleTaskFactory
は、以下のようになります。
この例では、TFF コードを構造化する際に従うと便利なパターンが浮き彫りにされています。非常に単純な演算を扱っていないのであれば、tff.federated_computation
内でビルディングブロックとして使用される tff.tf_computation
を別の場所で作成すると、コードが読みやすくなります。tff.federated_computation
の中では、これらのビルディングブロックは固有の演算子を使用してのみ接続されます。
以下のようにして、期待どおりに動作するかを検証します。
内部集約
最後のステップでは、オプションとして、異なる集約方法を簡単に合成できるようにするために、実際の集約をほかのファクトリにデリゲートできるようにします。
これは、ExampleTaskFactory
のコンストラクタにオプションの inner_factory
引数を作成して行います。指定されていない場合は、tff.aggregators.SumFactory
が使用され、前のセクションで直接使用された tff.federated_sum
演算が適用されます。
create
が呼び出されると、まず、inner_factory
の create
を呼び出して、同じ value_type
を使用して内部集約プロセスを作成できます。
initialize_fn
が返すプロセスの状態は、「この」プロセスが作成する状態と今作成した内部プロセスの状態の 2 つ合成です。
next_fn
の実装は、実際の集約が内部プロセスの next
関数にデリゲートされていることと、最終出力の作成方法において異なります。状態はやはり「この」状態と「内部」状態で構成されており、測定値は OrderedDict
と同様の方法で作成されています。
以下は、そのようなパターンの実装です。
inner_process.next
関数にデリゲートする場合、取得する戻り値の構造は tff.templates.MeasuredProcessOutput
で、state
、result
、および measurements
の 3 つのフィールドが伴います。合成される集約プロセスの全体的な戻り値の構造を作成する場合、state
と measurements
フィールドは一般に、共に作成されて戻されます。対照的に、 result
フィールドは集約される値に対応し、代わりに合成された集約を「通過」します。
state
オブジェクトは、ファクトリの実装の詳細として見なされる必要があり、したがって、合成は任意の構造にすることができます。ただし、measurements
はある時点でユーザーにレポートされる値に対応します。したがって、OrderedDict
を使用することをお勧めします。この場合、レポートされたメトリックが合成のどこから来ているかが明確になるように、合成された名前を付けます。
tff.federated_zip
演算子の使用にも注意してください。作成されるプロセスで制御される state
オブジェクトは tff.FederatedType
でなければなりません。代わりに戻された (this_state, inner_state)
が initialize_fn
にある場合、戻り値の型シグネチャは 2 タプルの tff.FederatedType
を含む tff.StructType
となります。tff.federated_zip
を使用すると、tff.FederatedType
がトップレベルに「昇格」されます。これは、戻される状態と測定値を準備する際に next_fn
で同様に使用されます。
最後に、これがデフォルトの内部集約でどのように使用されるかを確認します。
... そして、別の内部集約で確認します。たとえば、 ExampleTaskFactory
を使用します。
まとめ
このチュートリアルでは、集約ファクトリとして表現される汎用の集約ビルディングブロックを作成するために従うベストプラクティスを説明しました。次の 2 つの方法で設計を意図することで、汎用性が得られます。
*パラメータ化計算。*集計は、
tff.aggregators
と連携して、tff.learning.algorithms.build_weighted_fed_avg
などの必要な集計をパラメータ化するように設計された他の TFF モジュールに使用できる独立したビルディングブロックです。集約の合成。 集約ビルディングブロックは、他の集約ビルディングブロックと合成してより複雑な合成集約を作成することができます。