Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/ja/federated/tutorials/sparse_federated_learning.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

federated_select とスパースな集約によるクライアント効率の高い大規模なモデルの連合学習

このチュートリアルでは、TFF で tff.federated_select とスパースな集約を使用して、非常に大規模なモデルをトレーニングする方法を実演します。各クライアントデバイスはモデルのごく一部のみをダウンロードおよび更新します。このチュートリアルは自己完結型ですが、ここで使用されるいくつかのテクニックの基本的な説明については、tff.federated_select チュートリアルカスタム連合学習アルゴリズムチュートリアルを参照してください。

このチュートリアルでは、マルチラベル分類のロジスティック回帰を検討し、bag-of-words の特徴量表現に基づいて、どの「タグ」がテキスト文字列に関連付けられているかを予測します。重要なのは、通信とクライアント側の計算コストは固定定数 (MAX_TOKENS_SELECTED_PER_CLIENT) によって制御され、全体の語彙サイズ(実際の設定では非常に大きくなる可能性があります)に合わせてスケーリングされないということです。

#@test {"skip": true} !pip install --quiet --upgrade tensorflow-federated
import collections from collections.abc import Callable import itertools import numpy as np import tensorflow as tf import tensorflow_federated as tff

各クライアントは、最大でこの数の一意のトークンのモデルの重みの行を federated_select します。これにより、クライアントのローカルモデルのサイズと、実行されるサーバー->クライアント (federated_select) およびクライアント->サーバー (federated_aggregate) 通信の量が制限されます。

これを 1(各クライアントからのすべてのトークンが選択されていないことを確認)または大きな値に設定しても、このチュートリアルは正しく実行されますが、モデルの収束が影響を受ける可能性があります。

MAX_TOKENS_SELECTED_PER_CLIENT = 6

また、さまざまな型の定数をいくつか定義します。このコラボの場合、トークンは、データセットを解析した後の特定の単語の整数識別子です。

# There are some constraints on types # here that will require some explicit type conversions: # - `tff.federated_select` requires int32 # - `tf.SparseTensor` requires int64 indices. TOKEN_DTYPE = tf.int64 SELECT_KEY_DTYPE = tf.int32 # Type for counts of token occurences. TOKEN_COUNT_DTYPE = tf.int32 # A sparse feature vector can be thought of as a map # from TOKEN_DTYPE to FEATURE_DTYPE. # Our features are {0, 1} indicators, so we could potentially # use tf.int8 as an optimization. FEATURE_DTYPE = tf.int32

問題の設定: データセットとモデル

このチュートリアルでは、簡単に実験できるように小規模なトイデータセットを作成します。ただし、データセットの形式は Federated StackOverflow と互換性があり、前処理モデルアーキテクチャは、適応型連合最適化 の StackOverflow タグ予測問題から採用されています。

データセットの解析と前処理

NUM_OOV_BUCKETS = 1 BatchType = collections.namedtuple('BatchType', ['tokens', 'tags']) def build_to_ids_fn(word_vocab: list[str], tag_vocab: list[str]) -> Callable[[tf.Tensor], tf.Tensor]: """Constructs a function mapping examples to sequences of token indices.""" word_table_values = np.arange(len(word_vocab), dtype=np.int64) word_table = tf.lookup.StaticVocabularyTable( tf.lookup.KeyValueTensorInitializer(word_vocab, word_table_values), num_oov_buckets=NUM_OOV_BUCKETS) tag_table_values = np.arange(len(tag_vocab), dtype=np.int64) tag_table = tf.lookup.StaticVocabularyTable( tf.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values), num_oov_buckets=NUM_OOV_BUCKETS) def to_ids(example): """Converts a Stack Overflow example to a bag-of-words/tags format.""" sentence = tf.strings.join([example['tokens'], example['title']], separator=' ') # We represent that label (output tags) densely. raw_tags = example['tags'] tags = tf.strings.split(raw_tags, sep='|') tags = tag_table.lookup(tags) tags, _ = tf.unique(tags) tags = tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS) tags = tf.reduce_max(tags, axis=0) # We represent the features as a SparseTensor of {0, 1}s. words = tf.strings.split(sentence) tokens = word_table.lookup(words) tokens, _ = tf.unique(tokens) # Note: We could choose to use the word counts as the feature vector # instead of just {0, 1} values (see tf.unique_with_counts). tokens = tf.reshape(tokens, shape=(tf.size(tokens), 1)) tokens_st = tf.SparseTensor( tokens, tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE), dense_shape=(len(word_vocab) + NUM_OOV_BUCKETS,)) tokens_st = tf.sparse.reorder(tokens_st) return BatchType(tokens_st, tags) return to_ids
def build_preprocess_fn(word_vocab, tag_vocab): @tf.function def preprocess_fn(dataset): to_ids = build_to_ids_fn(word_vocab, tag_vocab) # We *don't* shuffle in order to make this colab deterministic for # easier testing and reproducibility. # But real-world training should use `.shuffle()`. return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE) return preprocess_fn

小規模なトイデータセット

12 語のグローバル語彙と 3 つのクライアントを使用して、小規模なトイデータセットを作成します。この小規模なサンプルは、エッジケースのテスト(たとえば、MAX_TOKENS_SELECTED_PER_CLIENT = 6 未満の個別のトークンを持つ 2 つのクライアントと、それ以上のトークンを持つ 1 つのクライアント)およびコードの開発に役立ちます。

ただし、このアプローチの実際のユースケースでは、数千万以上のグローバル語彙であり、各クライアントに数千の異なるトークンが表示される可能性があります。データの形式が同じであるため、より現実的なテストベッドの問題への拡張(tff.simulation.datasets.stackoverflow.load_data()データセット)は簡単です。

まず、単語とタグの語彙を定義します。

# Features FRUIT_WORDS = ['apple', 'orange', 'pear', 'kiwi'] VEGETABLE_WORDS = ['carrot', 'broccoli', 'arugula', 'peas'] FISH_WORDS = ['trout', 'tuna', 'cod', 'salmon'] WORD_VOCAB = FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS # Labels TAG_VOCAB = ['FRUIT', 'VEGETABLE', 'FISH']

ここで、小規模なローカルデータセットをもつ 3 つのクライアントを作成します。このチュートリアルを colab で実行している場合は、以下で開発した関数の出力を解釈/確認するために、「タブ内のセルのミラーリング」機能を使用してこのセルとその出力を固定すると便利な場合があります。

preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB) def make_dataset(raw): d = tf.data.Dataset.from_tensor_slices( # Matches the StackOverflow formatting collections.OrderedDict( tokens=tf.constant([t[0] for t in raw]), tags=tf.constant([t[1] for t in raw]), title=['' for _ in raw])) d = preprocess_fn(d) return d # 4 distinct tokens CLIENT1_DATASET = make_dataset([ ('apple orange apple orange', 'FRUIT'), ('carrot trout', 'VEGETABLE|FISH'), ('orange apple', 'FRUIT'), ('orange', 'ORANGE|CITRUS') # 2 OOV tag ]) # 6 distinct tokens CLIENT2_DATASET = make_dataset([ ('pear cod', 'FRUIT|FISH'), ('arugula peas', 'VEGETABLE'), ('kiwi pear', 'FRUIT'), ('sturgeon', 'FISH'), # OOV word ('sturgeon bass', 'FISH') # 2 OOV words ]) # A client with all possible words & tags (13 distinct tokens). # With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model # slices for all tokens that occur on this client. CLIENT3_DATASET = make_dataset([ (' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)), # Mathe the OOV token and 'salmon' occur in the largest number # of examples on this client: ('salmon oovword', 'FISH|OOVTAG') ]) print('Word vocab') for i, word in enumerate(WORD_VOCAB): print(f'{i:2d} {word}') print('\nTag vocab') for i, tag in enumerate(TAG_VOCAB): print(f'{i:2d} {tag}')
Word vocab 0 apple 1 orange 2 pear 3 kiwi 4 carrot 5 broccoli 6 arugula 7 peas 8 trout 9 tuna 10 cod 11 salmon Tag vocab 0 FRUIT 1 VEGETABLE 2 FISH

入力フィーチャ(トークン/単語)とラベル(ポストタグ)の生の数の定数を定義します。OOV トークン/タグを追加するため、実際の入出力スペースは NUM_OOV_BUCKETS = 1 大きくなります。

NUM_WORDS = len(WORD_VOCAB) NUM_TAGS = len(TAG_VOCAB) WORD_VOCAB_SIZE = NUM_WORDS + NUM_OOV_BUCKETS TAG_VOCAB_SIZE = NUM_TAGS + NUM_OOV_BUCKETS

データセットのバッチバージョンと個々のバッチを作成します。これは、コードのテストに役立ちます。

batched_dataset1 = CLIENT1_DATASET.batch(2) batched_dataset2 = CLIENT2_DATASET.batch(3) batched_dataset3 = CLIENT3_DATASET.batch(2) batch1 = next(iter(batched_dataset1)) batch2 = next(iter(batched_dataset2)) batch3 = next(iter(batched_dataset3))

スパース入力でモデルを定義する

タグごとに単純な独立ロジスティック回帰モデルを使用します。

def create_logistic_model(word_vocab_size: int, vocab_tags_size: int): model = tf.keras.models.Sequential([ tf.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True), tf.keras.layers.Dense( vocab_tags_size, activation='sigmoid', kernel_initializer=tf.keras.initializers.zeros, # For simplicity, don't use a bias vector; this means the model # is a single tensor, and we only need sparse aggregation of # the per-token slices of the model. Generalizing to also handle # other model weights that are fully updated # (non-dense broadcast and aggregate) would be a good exercise. use_bias=False), ]) return model

まず、予測を行って、それが機能することを確認しましょう。

model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE) p = model.predict(batch1.tokens) print(p)
[[0.5 0.5 0.5 0.5] [0.5 0.5 0.5 0.5]]

そして、いくつかの簡単な集中トレーニングを実行します。

model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001), loss=tf.keras.losses.BinaryCrossentropy()) model.train_on_batch(batch1.tokens, batch1.tags)

連合計算のビルディングブロック

連合平均化アルゴリズムの単純なバージョンを実装しますが、重要な違いは、各デバイスはモデルの関連するサブセットのみをダウンロードし、そのサブセットへの更新のみを提供するということです。

MAX_TOKENS_SELECTED_PER_CLIENT の省略形として M を使用します。上位レベルでは、1 ラウンドのトレーニングには次の手順が含まれます。

  1. 参加している各クライアントは、ローカルデータセットをスキャンし、入力文字列を解析して、正しいトークン(int インデックス)にマッピングします。これには、グローバル(大規模)ディクショナリへのアクセスが必要です(これは、機能ハッシュ手法を使用して回避できる可能性があります)。次に、各トークンが発生する回数をスパースにカウントします。U の一意のトークンがデバイスで発生する場合、トレーニングする num_actual_tokens = min(U, M) の最も頻繁なトークンを選択します。

  2. クライアントは federated_select を使用して、サーバーから num_actual_tokens で選択されたトークンのモデル係数を取得します。各モデルスライスは形状 (TAG_VOCAB_SIZE, ) のテンソルであるため、クライアントに送信されるデータの合計は最大でサイズ TAG_VOCAB_SIZE * M になります(以下の注意事項を参照)。

  3. クライアントは、マッピング global_token -> local_token を作成します。ローカルトークン(int index)は、選択されたトークンのリスト内のグローバルトークンのインデックスです。

  4. クライアントは、範囲 [0, num_actual_tokens) から最大 M トークンの係数のみを持つグローバルモデルの「小規模な」バージョンを使用します。global -> local マッピングは、選択したモデルスライスからこのモデルの密なパラメータを初期化するために使用されます。

  5. クライアントは、global -> local マッピングで前処理されたデータに対して SGD を使用してローカルモデルをトレーニングします。

  6. クライアントは、local -> global マッピングを使用して行にインデックスを付けることにより、ローカルモデルのパラメータを IndexedSlices 更新に変換します。サーバーは、スパースな集約を使用してこれらの更新を集約します。

  7. サーバーは、上記の集約の(密な)結果を取得し、それを参加しているクライアントの数で除算し、結果の平均更新をグローバルモデルに適用します。

このセクションでは、これらのステップの構成要素を構築します。これらの構成要素は、1 つ のトレーニングラウンドの完全なロジックをキャプチャする最終的な federated_computation に結合されます。

注意: 上記に説明されていない 1 つの技術的な詳細があります。federated_select とローカルモデルの構築では、静的に既知の形状が必要であるため、動的なクライアントごとの num_actual_tokens サイズを使用できません。代わりに、静的な値 M を使用し、必要に応じてパディングを追加します。これは、アルゴリズムのセマンティクスには影響しません。

クライアントトークンをカウントし、federated_select にスライスするモデルを決定する

各デバイスは、モデルのどの「スライス」がローカルトレーニングデータセットに関連しているかを判断する必要があります。ここでは、クライアントトレーニングデータセットの各トークンを含むサンプルの数を(スパースに)カウントします。

@tf.function def token_count_fn(token_counts, batch): """Adds counts from `batch` to the running `token_counts` sum.""" # Sum across the batch dimension. flat_tokens = tf.sparse.reduce_sum( batch.tokens, axis=0, output_is_sparse=True) flat_tokens = tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE) return tf.sparse.add(token_counts, flat_tokens)
# Simple tests # Create the initial zero token counts using empty tensors. initial_token_counts = tf.SparseTensor( indices=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE), values=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE), dense_shape=(WORD_VOCAB_SIZE,)) client_token_counts = batched_dataset1.reduce(initial_token_counts, token_count_fn) tokens = tf.reshape(client_token_counts.indices, (-1,)).numpy() print('tokens:', tokens) np.testing.assert_array_equal(tokens, [0, 1, 4, 8]) # The count is the number of *examples* in which the token/word # occurs, not the total number of occurences, since we still featurize # multiple occurences in the same example as a "1". counts = client_token_counts.values.numpy() print('counts:', counts) np.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8] counts: [2 3 1 1]

デバイスで最も頻繁に発生する MAX_TOKENS_SELECTED_PER_CLIENT トークンに対応するモデルパラメータを選択します。デバイスで発生するトークンの数がこれより少ない場合は、リストを埋めて federated_select を使用できるようにします。

(発生確率に基づいて)トークンをランダムに選択するなど、他の戦略の方がおそらく優れていることに注意してください。これにより、(クライアントがデータを持っている)モデルのすべてのスライスが更新されます。

@tf.function def keys_for_client(client_dataset, max_tokens_per_client): """Computes a set of max_tokens_per_client keys.""" initial_token_counts = tf.SparseTensor( indices=tf.zeros((0, 1), dtype=TOKEN_DTYPE), values=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE), dense_shape=(WORD_VOCAB_SIZE,)) client_token_counts = client_dataset.reduce(initial_token_counts, token_count_fn) # Find the most-frequently occuring tokens tokens = tf.reshape(client_token_counts.indices, shape=(-1,)) counts = client_token_counts.values perm = tf.argsort(counts, direction='DESCENDING') tokens = tf.gather(tokens, perm) counts = tf.gather(counts, perm) num_raw_tokens = tf.shape(tokens)[0] actual_num_tokens = tf.minimum(max_tokens_per_client, num_raw_tokens) selected_tokens = tokens[:actual_num_tokens] paddings = [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]] padded_tokens = tf.pad(selected_tokens, paddings=paddings) # Make sure the type is statically determined padded_tokens = tf.reshape(padded_tokens, shape=(max_tokens_per_client,)) # We will pass these tokens as keys into `federated_select`, which # requires SELECT_KEY_DTYPE=tf.int32 keys. padded_tokens = tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE) return padded_tokens, actual_num_tokens
# Simple test # Case 1: actual_num_tokens > max_tokens_per_client selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 3) assert tf.size(selected_tokens) == 3 assert actual_num_tokens == 3 # Case 2: actual_num_tokens < max_tokens_per_client selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 10) assert tf.size(selected_tokens) == 10 assert actual_num_tokens == 4

グローバルトークンをローカルトークンにマップする

上記の選択により、オンデバイスモデルに使用する [0, actual_num_tokens) の範囲の密なトークンのセットが得られます。ただし、読み取ったデータセットには、はるかに大きなグローバル語彙範囲 [0, WORD_VOCAB_SIZE) のトークンが含まれています。

したがって、グローバルトークンを対応するローカルトークンにマップする必要があります。ローカルトークン ID は、前の手順で計算された selected_tokens テンソルへのインデックスによって与えられます。

@tf.function def map_to_local_token_ids(client_data, client_keys): global_to_local = tf.lookup.StaticHashTable( # Note int32 -> int64 maps are not supported tf.lookup.KeyValueTensorInitializer( keys=tf.cast(client_keys, dtype=TOKEN_DTYPE), # Note we need to use tf.shape, not the static # shape client_keys.shape[0] values=tf.range(0, limit=tf.shape(client_keys)[0], dtype=TOKEN_DTYPE)), # We use -1 for tokens that were not selected, which can occur for clients # with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens. # We will simply remove these invalid indices from the batch below. default_value=-1) def to_local_ids(sparse_tokens): indices_t = tf.transpose(sparse_tokens.indices) batch_indices = indices_t[0] # First column tokens = indices_t[1] # Second column tokens = tf.map_fn( lambda global_token_id: global_to_local.lookup(global_token_id), tokens) # Remove tokens that aren't actually available (looked up as -1): available_tokens = tokens >= 0 tokens = tokens[available_tokens] batch_indices = batch_indices[available_tokens] updated_indices = tf.transpose( tf.concat([[batch_indices], [tokens]], axis=0)) st = tf.sparse.SparseTensor( updated_indices, tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE), # Each client has at most MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens. dense_shape=[sparse_tokens.dense_shape[0], MAX_TOKENS_SELECTED_PER_CLIENT]) st = tf.sparse.reorder(st) return st return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))
# Simple test client_keys, actual_num_tokens = keys_for_client( batched_dataset3, MAX_TOKENS_SELECTED_PER_CLIENT) client_keys = client_keys[:actual_num_tokens] d = map_to_local_token_ids(batched_dataset3, client_keys) batch = next(iter(d)) all_tokens = tf.gather(batch.tokens.indices, indices=1, axis=1) # Confirm we have local indices in the range [0, MAX): assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT assert tf.math.reduce_max(all_tokens) >= 0

各クライアントでローカル(サブ)モデルをトレーニングする

注意: federated_select は、選択したスライスを、選択キーと同じ順序で tf.data.Datasetとして返します。したがって、最初に、そのようなデータセットを取得し、それをクライアントモデルのモデルの重みとして使用できる単一の密なテンソルに変換する効用関数を定義します。

@tf.function def slices_dataset_to_tensor(slices_dataset): """Convert a dataset of slices to a tensor.""" # Use batching to gather all of the slices into a single tensor. d = slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT, drop_remainder=False) iter_d = iter(d) tensor = next(iter_d) # Make sure we have consumed everything opt = iter_d.get_next_as_optional() tf.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY') return tensor
# Simple test weights = np.random.random( size=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32) model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(weights) weights2 = slices_dataset_to_tensor(model_slices_as_dataset) np.testing.assert_array_equal(weights, weights2)

各クライアントで実行される単純なローカルトレーニングループを定義するために必要なすべてのコンポーネントが揃いました。

@tf.function def client_train_fn(model, client_optimizer, model_slices_as_dataset, client_data, client_keys, actual_num_tokens): initial_model_weights = slices_dataset_to_tensor(model_slices_as_dataset) assert len(model.trainable_variables) == 1 model.trainable_variables[0].assign(initial_model_weights) # Only keep the "real" (unpadded) keys. client_keys = client_keys[:actual_num_tokens] client_data = map_to_local_token_ids(client_data, client_keys) loss_fn = tf.keras.losses.BinaryCrossentropy() for features, labels in client_data: with tf.GradientTape() as tape: predictions = model(features) loss = loss_fn(labels, predictions) grads = tape.gradient(loss, model.trainable_variables) client_optimizer.apply_gradients(zip(grads, model.trainable_variables)) model_weights_delta = model.trainable_weights[0] - initial_model_weights model_weights_delta = tf.slice(model_weights_delta, begin=[0, 0], size=[actual_num_tokens, -1]) return client_keys, model_weights_delta
# Simple test # Note if you execute this cell a second time, you need to also re-execute # the preceeding cell to avoid "tf.function-decorated function tried to # create variables on non-first call" errors. on_device_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE) client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001) client_keys, actual_num_tokens = keys_for_client( batched_dataset2, MAX_TOKENS_SELECTED_PER_CLIENT) model_slices_as_dataset = tf.data.Dataset.from_tensor_slices( np.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE), dtype=np.float32)) keys, delta = client_train_fn( on_device_model, client_optimizer, model_slices_as_dataset, client_data=batched_dataset3, client_keys=client_keys, actual_num_tokens=actual_num_tokens) print(delta)

IndexedSlices を集約する

tff.federated_aggregate を使用して、IndexedSlices のスパースな連合集約を作成します。この単純な実装には、density_shape が事前に静的に認識されているという制約があります。また、この集約は、セミスパース(クライアント->サーバー通信がスパース)ですが、サーバーは accumulate およびmerge で集約の密な表現を維持し、この密な表現を出力していることにも注意してください。

def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape): """ Sums IndexedSlices@CLIENTS to a dense @SERVER Tensor. Intermediate aggregation is performed by converting to a dense representation, which may not be suitable for all applications. Args: slice_indices: An IndexedSlices.indices tensor @CLIENTS. slice_values: An IndexedSlices.values tensor @CLIENTS. dense_shape: A statically known dense shape. Returns: A dense tensor placed @SERVER representing the sum of the client's IndexedSclies. """ slices_dtype = slice_values.type_signature.member.dtype zero = tff.tf_computation( lambda: tf.zeros(dense_shape, dtype=slices_dtype))() @tf.function def accumulate_slices(dense, client_value): indices, slices = client_value # There is no built-in way to add `IndexedSlices`, but # tf.convert_to_tensor is a quick way to convert to a dense representation # so we can add them. return dense + tf.convert_to_tensor( tf.IndexedSlices(slices, indices, dense_shape)) return tff.federated_aggregate( (slice_indices, slice_values), zero=zero, accumulate=tff.tf_computation(accumulate_slices), merge=tff.tf_computation(lambda d1, d2: tf.add(d1, d2, name='merge')), report=tff.tf_computation(lambda d: d))

テストとして最小限の federated_computation を作成します

dense_shape = (6, 2) indices_type = tff.TensorType(tf.int64, (None,)) values_type = tff.TensorType(tf.float32, (None, 2)) client_slice_type = tff.type_at_clients( (indices_type, values_type)) @tff.federated_computation(client_slice_type) def test_sum_indexed_slices(indices_values_at_client): indices, values = indices_values_at_client return federated_indexed_slices_sum(indices, values, dense_shape) print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices( values=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]], dtype=np.float32), indices=[2, 0, 1, 5], dense_shape=dense_shape) y = tf.IndexedSlices( values=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32), indices=[1, 3], dense_shape=dense_shape) # Sum one. result = test_sum_indexed_slices([(x.indices, x.values)]) np.testing.assert_array_equal(tf.convert_to_tensor(x), result) # Sum two. expected = [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]] result = test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)]) np.testing.assert_array_almost_equal(expected, result)

federated_computation に全てをまとめる

TFF を使用して、コンポーネントを tff.federated_computation にまとめます。

DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE) client_data_type = tff.SequenceType(batched_dataset1.element_spec) model_type = tff.TensorType(tf.float32, shape=DENSE_MODEL_SHAPE)

連合平均化に基づく基本的なサーバートレーニング関数を使用し、サーバー学習率 1.0 で更新を適用します。クライアント提供のモデルを単純に平均化するのではなく、モデルに更新(デルタ)を適用することが重要です。そうでなければ、モデルの特定のスライスが特定のラウンドでどのクライアントによってもトレーニングされていない場合、その係数はゼロになる可能性があります。

@tff.tf_computation def server_update(current_model_weights, update_sum, num_clients): average_update = update_sum / num_clients return current_model_weights + average_update

さらにいくつかの tff.tf_computation の要素が必要です。

# Function to select slices from the model weights in federated_select: select_fn = tff.tf_computation( lambda model_weights, index: tf.gather(model_weights, index)) # We need to wrap `client_train_fn` as a `tff.tf_computation`, making # sure we do any operations that might construct `tf.Variable`s outside # of the `tf.function` we are wrapping. @tff.tf_computation def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys, actual_num_tokens): # Note this is amaller than the global model, using # MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE. # We would like a model of size `actual_num_tokens`, but we # can't build the model dynamically, so we will slice off the padded # weights at the end. client_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE) client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1) return client_train_fn(client_model, client_optimizer, model_slices_as_dataset, client_data, client_keys, actual_num_tokens) @tff.tf_computation def keys_for_client_tff(client_data): return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)

以上ですべての要素をまとめる準備が整いました。

@tff.federated_computation( tff.type_at_server(model_type), tff.type_at_clients(client_data_type)) def sparse_model_update(server_model, client_data): max_tokens = tff.federated_value(MAX_TOKENS_SELECTED_PER_CLIENT, tff.SERVER) keys_at_clients, actual_num_tokens = tff.federated_map( keys_for_client_tff, client_data) model_slices = tff.federated_select(keys_at_clients, max_tokens, server_model, select_fn) update_keys, update_slices = tff.federated_map( client_train_fn_tff, (model_slices, client_data, keys_at_clients, actual_num_tokens)) dense_update_sum = federated_indexed_slices_sum(update_keys, update_slices, DENSE_MODEL_SHAPE) num_clients = tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS)) updated_server_model = tff.federated_map( server_update, (server_model, dense_update_sum, num_clients)) return updated_server_model print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)

モデルをトレーニングしましょう

トレーニング関数ができたので、試してみましょう。

server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE) server_model.compile( # Compile to make evaluation easy. optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.0), # Unused loss=tf.keras.losses.BinaryCrossentropy(), metrics=[ tf.keras.metrics.Precision(name='precision'), tf.keras.metrics.AUC(name='auc'), tf.keras.metrics.Recall(top_k=2, name='recall_at_2'), ]) def evaluate(model, dataset, name): metrics = model.evaluate(dataset, verbose=0) metrics_str = ', '.join([f'{k}={v:.2f}' for k, v in (zip(server_model.metrics_names, metrics))]) print(f'{name}: {metrics_str}')
print('Before training') evaluate(server_model, batched_dataset1, 'Client 1') evaluate(server_model, batched_dataset2, 'Client 2') evaluate(server_model, batched_dataset3, 'Client 3') model_weights = server_model.trainable_weights[0] client_datasets = [batched_dataset1, batched_dataset2, batched_dataset3] for _ in range(10): # Run 10 rounds of FedAvg # We train on 1, 2, or 3 clients per round, selecting # randomly. cohort_size = np.random.randint(1, 4) clients = np.random.choice([0, 1, 2], cohort_size, replace=False) print('Training on clients', clients) model_weights = sparse_model_update( model_weights, [client_datasets[i] for i in clients]) server_model.set_weights([model_weights]) print('After training') evaluate(server_model, batched_dataset1, 'Client 1') evaluate(server_model, batched_dataset2, 'Client 2') evaluate(server_model, batched_dataset3, 'Client 3')
Before training Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60 Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50 Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40 Training on clients [0 1] Training on clients [0 2 1] Training on clients [2 0] Training on clients [1 0 2] Training on clients [2] Training on clients [2 0] Training on clients [1 2 0] Training on clients [0] Training on clients [2] Training on clients [1 2] After training Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80 Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00 Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80