Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/pt-br/federated/tutorials/sparse_federated_learning.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

Aprendizado federado de modelos grandes eficiente nos clientes via federated_select e agregação esparsa

Este tutorial mostra como o TFF pode ser usado para treinar um modelo muito grande, em que cada dispositivo cliente somente baixe e atualize uma pequena parte do modelo usando tff.federated_select e agregação esparsa. Embora este tutorial seja relativamente autônomo, o tutorial sobre tff.federated_select e o tutorial sobre algoritmos personalizados de aprendizado de máquina apresentam boas introduções a algumas das técnicas usadas aqui.

Neste tutorial, consideramos a regressão linguística para classificação com vários rótulos, prevendo quais "tags" estão associadas a uma string de texto com base em uma representação de características "saco-de-palavras". A comunicação e os custos de computação no lado do cliente são controlados por uma constante fixa (MAX_TOKENS_SELECTED_PER_CLIENT) e não escalonam junto com o tamanho geral do vocabulário, que pode ser extremamente grande em cenários práticos.

#@test {"skip": true} !pip install --quiet --upgrade tensorflow-federated
import collections from collections.abc import Callable import itertools import numpy as np import tensorflow as tf import tensorflow_federated as tff

Cada cliente fará a seleção federada (federated_select) das linhas dos pesos do modelo para, no máximo, essa quantidade de tokens únicos. Isso limita o tamanho máximo do modelo local do cliente e a quantidade de comunicação servidor -> cliente (federated_select) e cliente -> servidor (federated_aggregate) realizada.

Este tutorial ainda deverá ser executado corretamente mesmo se você definir esse valor bem pequeno, como 1 (garantindo que nem todos os tokens de cada cliente sejam selecionados), ou se definir um valor grande, embora a convergência do modelo possa ser afetada.

MAX_TOKENS_SELECTED_PER_CLIENT = 6

Também definimos algumas constantes de diversos tipos. Para este Colab, um token é um identificador inteiro para uma palavra específica após processar o dataset.

# There are some constraints on types # here that will require some explicit type conversions: # - `tff.federated_select` requires int32 # - `tf.SparseTensor` requires int64 indices. TOKEN_DTYPE = tf.int64 SELECT_KEY_DTYPE = tf.int32 # Type for counts of token occurences. TOKEN_COUNT_DTYPE = tf.int32 # A sparse feature vector can be thought of as a map # from TOKEN_DTYPE to FEATURE_DTYPE. # Our features are {0, 1} indicators, so we could potentially # use tf.int8 as an optimization. FEATURE_DTYPE = tf.int32

Definição do problema – Dataset e modelo

Construímos um dataset de exemplo minúsculo para fácil experimentação neste tutorial. Porém, o formato do dataset é compatível com Federated StackOverflow, e o pré-processamento e a arquitetura do modelo são os mesmos do problema de previsão de tags do StackOverflow em Otimização federada adaptativa.

Leitura e processamento do dataset

NUM_OOV_BUCKETS = 1 BatchType = collections.namedtuple('BatchType', ['tokens', 'tags']) def build_to_ids_fn(word_vocab: list[str], tag_vocab: list[str]) -> Callable[[tf.Tensor], tf.Tensor]: """Constructs a function mapping examples to sequences of token indices.""" word_table_values = np.arange(len(word_vocab), dtype=np.int64) word_table = tf.lookup.StaticVocabularyTable( tf.lookup.KeyValueTensorInitializer(word_vocab, word_table_values), num_oov_buckets=NUM_OOV_BUCKETS) tag_table_values = np.arange(len(tag_vocab), dtype=np.int64) tag_table = tf.lookup.StaticVocabularyTable( tf.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values), num_oov_buckets=NUM_OOV_BUCKETS) def to_ids(example): """Converts a Stack Overflow example to a bag-of-words/tags format.""" sentence = tf.strings.join([example['tokens'], example['title']], separator=' ') # We represent that label (output tags) densely. raw_tags = example['tags'] tags = tf.strings.split(raw_tags, sep='|') tags = tag_table.lookup(tags) tags, _ = tf.unique(tags) tags = tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS) tags = tf.reduce_max(tags, axis=0) # We represent the features as a SparseTensor of {0, 1}s. words = tf.strings.split(sentence) tokens = word_table.lookup(words) tokens, _ = tf.unique(tokens) # Note: We could choose to use the word counts as the feature vector # instead of just {0, 1} values (see tf.unique_with_counts). tokens = tf.reshape(tokens, shape=(tf.size(tokens), 1)) tokens_st = tf.SparseTensor( tokens, tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE), dense_shape=(len(word_vocab) + NUM_OOV_BUCKETS,)) tokens_st = tf.sparse.reorder(tokens_st) return BatchType(tokens_st, tags) return to_ids
def build_preprocess_fn(word_vocab, tag_vocab): @tf.function def preprocess_fn(dataset): to_ids = build_to_ids_fn(word_vocab, tag_vocab) # We *don't* shuffle in order to make this colab deterministic for # easier testing and reproducibility. # But real-world training should use `.shuffle()`. return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE) return preprocess_fn

Um dataset de exemplo minúsculo

Construímos um dataset de exemplo minúsculo com um vocabulário global de 12 palavras e 3 clientes. Esse exemplo minúsculo é útil para testar casos extremos (por exemplo: temos dois clientes com menos de MAX_TOKENS_SELECTED_PER_CLIENT = 6 tokens distintos e um com mais) e para desenvolver o código.

Porém, os casos de uso reais dessa estratégia seriam vocabulários globais com dezenas de milhões de palavras ou mais, com talvez milhares de tokens distintos em cada cliente. Como o formato dos dados é igual, a extensão para problemas de teste mais realistas, como o dataset tff.simulation.datasets.stackoverflow.load_data(), seria bem direta.

Primeiro, definimos os vocabulários de palavras e tags.

# Features FRUIT_WORDS = ['apple', 'orange', 'pear', 'kiwi'] VEGETABLE_WORDS = ['carrot', 'broccoli', 'arugula', 'peas'] FISH_WORDS = ['trout', 'tuna', 'cod', 'salmon'] WORD_VOCAB = FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS # Labels TAG_VOCAB = ['FRUIT', 'VEGETABLE', 'FISH']

Agora, criamos 3 clientes com datasets locais pequenos. Se você estiver executando este tutorial no Colab, pode ser útil usar o recurso "Mirror cell in tab" para fixar essa célula e sua saída para poder interpretar/verificar a saída das funções desenvolvidas abaixo.

preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB) def make_dataset(raw): d = tf.data.Dataset.from_tensor_slices( # Matches the StackOverflow formatting collections.OrderedDict( tokens=tf.constant([t[0] for t in raw]), tags=tf.constant([t[1] for t in raw]), title=['' for _ in raw])) d = preprocess_fn(d) return d # 4 distinct tokens CLIENT1_DATASET = make_dataset([ ('apple orange apple orange', 'FRUIT'), ('carrot trout', 'VEGETABLE|FISH'), ('orange apple', 'FRUIT'), ('orange', 'ORANGE|CITRUS') # 2 OOV tag ]) # 6 distinct tokens CLIENT2_DATASET = make_dataset([ ('pear cod', 'FRUIT|FISH'), ('arugula peas', 'VEGETABLE'), ('kiwi pear', 'FRUIT'), ('sturgeon', 'FISH'), # OOV word ('sturgeon bass', 'FISH') # 2 OOV words ]) # A client with all possible words & tags (13 distinct tokens). # With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model # slices for all tokens that occur on this client. CLIENT3_DATASET = make_dataset([ (' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)), # Mathe the OOV token and 'salmon' occur in the largest number # of examples on this client: ('salmon oovword', 'FISH|OOVTAG') ]) print('Word vocab') for i, word in enumerate(WORD_VOCAB): print(f'{i:2d} {word}') print('\nTag vocab') for i, tag in enumerate(TAG_VOCAB): print(f'{i:2d} {tag}')
Word vocab 0 apple 1 orange 2 pear 3 kiwi 4 carrot 5 broccoli 6 arugula 7 peas 8 trout 9 tuna 10 cod 11 salmon Tag vocab 0 FRUIT 1 VEGETABLE 2 FISH

Defina constantes para os números brutos das características de entrada (tokens/palavras) e rótulos (tags). Nossos espaços reais de entrada/saída são NUM_OOV_BUCKETS = 1 maiores, pois adicionamos um token/tag fora do vocabulário.

NUM_WORDS = len(WORD_VOCAB) NUM_TAGS = len(TAG_VOCAB) WORD_VOCAB_SIZE = NUM_WORDS + NUM_OOV_BUCKETS TAG_VOCAB_SIZE = NUM_TAGS + NUM_OOV_BUCKETS

Crie versões dos datasets divididas em lotes e lotes individuais, que serão úteis no código de teste à medida que prosseguirmos.

batched_dataset1 = CLIENT1_DATASET.batch(2) batched_dataset2 = CLIENT2_DATASET.batch(3) batched_dataset3 = CLIENT3_DATASET.batch(2) batch1 = next(iter(batched_dataset1)) batch2 = next(iter(batched_dataset2)) batch3 = next(iter(batched_dataset3))

Defina um modelo com entradas esparsas

Usamos um modelo simples de regressão logística independente para cada tag.

def create_logistic_model(word_vocab_size: int, vocab_tags_size: int): model = tf.keras.models.Sequential([ tf.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True), tf.keras.layers.Dense( vocab_tags_size, activation='sigmoid', kernel_initializer=tf.keras.initializers.zeros, # For simplicity, don't use a bias vector; this means the model # is a single tensor, and we only need sparse aggregation of # the per-token slices of the model. Generalizing to also handle # other model weights that are fully updated # (non-dense broadcast and aggregate) would be a good exercise. use_bias=False), ]) return model

Vamos confirmar se está funcionando fazendo previsões primeiro:

model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE) p = model.predict(batch1.tokens) print(p)
[[0.5 0.5 0.5 0.5] [0.5 0.5 0.5 0.5]]

E um treinamento centralizado simples:

model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001), loss=tf.keras.losses.BinaryCrossentropy()) model.train_on_batch(batch1.tokens, batch1.tags)

Blocos de construção para a computação federada

Vamos implementar uma versão simples do algoritmo de cálculo federado de médias, com a diferença chave de que cada dispositivo baixa apenas um subconjunto relevante do modelo e contribui apenas com atualizações desse subconjunto.

Usamos M como abreviação para MAX_TOKENS_SELECTED_PER_CLIENT. De forma geral, uma rodada de treinamento envolve estas etapas:

  1. Cada cliente participante varre seu próprio dataset local, processando as strings de entrada e mapeando-as para os tokens corretos (índices inteiros). Isso requer acesso ao dicionário (grande) global (isso pode ser possivelmente evitado usando técnicas de hash de características). Em seguida, fazemos a contagem esparsa de quantas vezes cada token ocorre. Se U tokens únicos ocorrerem no dispositivo, escolhemos os num_actual_tokens = min(U, M) tokens mais frequentes para fazer o treinamento.

  2. Os clientes usam federated_select para obter os coeficientes do modelo para os num_actual_tokens tokens selecionados a partir do servidor. Cada fatia do modelo é um tensor de formato (TAG_VOCAB_SIZE, ), então o total de dados transmitidos para o cliente tem tamanho máximo de TAG_VOCAB_SIZE * M (confira a observação abaixo).

  3. Os clientes constroem um mapeamento global_token -> local_token, em que o token local (índice inteiro) é o índice do token global na lista de tokens selecionados.

  4. Os clientes usam uma versão "pequena" do modelo global que tem apenas os coeficientes de no máximo M tokens do intervalo [0, num_actual_tokens). O mapeamento global -> local é usado para inicializar os parâmetros densos desse modelo a partir das fatias do modelo selecionadas.

  5. Os clientes treinam seu modelo local usando o método do gradiente estocástico com dados pré-processados com o mapeamento global -> local.

  6. Os clientes transformam os parâmetros de seu modelo local em IndexedSlices atualizações usando o mapeamento local -> global para indexar as linhas. O servidor agrega essas atualizações usando uma agregação de soma esparsa.

  7. O servidor pega o resultado (denso) da agregação acima, divide pelo número de clientes participantes e aplica a atualização da média resultante ao modelo global.

Nesta seção, montamos os blocos de construção para essas etapas, que serão combinados em uma computação federada (federated_computation) final que captura a lógica completa de uma rodada de treinamento.

OBSERVAÇÃO: a descrição acima oculta um detalhe técnico: tanto federated_select quanto a construção do modelo local exigem formatos estatisticamente conhecidos e, portanto, não podemos usar o tamanho num_actual_tokens dinâmico por cliente. Em vez disso, usamos o valor estático M, adicionando preenchimento onde necessário. Isso não impacta a semântica do algoritmo.

Conte os tokens de clientes e decida quais fatias do modelo devem ser selecionadas via federated_select

Cada dispositivo precisa decidir quais "fatias" do modelo são relevantes para seu dataset de treinamento local. Para o nosso problema, fazemos isso (esparsamente!) contando quantos exemplos contêm cada token no dataset de treinamento do cliente.

@tf.function def token_count_fn(token_counts, batch): """Adds counts from `batch` to the running `token_counts` sum.""" # Sum across the batch dimension. flat_tokens = tf.sparse.reduce_sum( batch.tokens, axis=0, output_is_sparse=True) flat_tokens = tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE) return tf.sparse.add(token_counts, flat_tokens)
# Simple tests # Create the initial zero token counts using empty tensors. initial_token_counts = tf.SparseTensor( indices=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE), values=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE), dense_shape=(WORD_VOCAB_SIZE,)) client_token_counts = batched_dataset1.reduce(initial_token_counts, token_count_fn) tokens = tf.reshape(client_token_counts.indices, (-1,)).numpy() print('tokens:', tokens) np.testing.assert_array_equal(tokens, [0, 1, 4, 8]) # The count is the number of *examples* in which the token/word # occurs, not the total number of occurences, since we still featurize # multiple occurences in the same example as a "1". counts = client_token_counts.values.numpy() print('counts:', counts) np.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8] counts: [2 3 1 1]

Vamos selecionar os parâmetros do modelo que correspondem aos MAX_TOKENS_SELECTED_PER_CLIENT tokens que ocorrem com maior frequência no dispositivo. Se um número menor que esse de tokens ocorrer no dispositivo, preenchemos a lista para permitir o uso de federated_select.

Observe que outras estratégias são possivelmente melhores, como, por exemplo, selecionar tokens aleatoriamente (talvez com base na probabilidade de ocorrência), o que garantiria que todas as fatias do modelo (para as quais o cliente tem dados) tenham alguma chance de serem atualizadas.

@tf.function def keys_for_client(client_dataset, max_tokens_per_client): """Computes a set of max_tokens_per_client keys.""" initial_token_counts = tf.SparseTensor( indices=tf.zeros((0, 1), dtype=TOKEN_DTYPE), values=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE), dense_shape=(WORD_VOCAB_SIZE,)) client_token_counts = client_dataset.reduce(initial_token_counts, token_count_fn) # Find the most-frequently occuring tokens tokens = tf.reshape(client_token_counts.indices, shape=(-1,)) counts = client_token_counts.values perm = tf.argsort(counts, direction='DESCENDING') tokens = tf.gather(tokens, perm) counts = tf.gather(counts, perm) num_raw_tokens = tf.shape(tokens)[0] actual_num_tokens = tf.minimum(max_tokens_per_client, num_raw_tokens) selected_tokens = tokens[:actual_num_tokens] paddings = [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]] padded_tokens = tf.pad(selected_tokens, paddings=paddings) # Make sure the type is statically determined padded_tokens = tf.reshape(padded_tokens, shape=(max_tokens_per_client,)) # We will pass these tokens as keys into `federated_select`, which # requires SELECT_KEY_DTYPE=tf.int32 keys. padded_tokens = tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE) return padded_tokens, actual_num_tokens
# Simple test # Case 1: actual_num_tokens > max_tokens_per_client selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 3) assert tf.size(selected_tokens) == 3 assert actual_num_tokens == 3 # Case 2: actual_num_tokens < max_tokens_per_client selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 10) assert tf.size(selected_tokens) == 10 assert actual_num_tokens == 4

Mapeie tokens globais em tokens locais

A seleção acima nos fornece um conjunto denso de tokens no intervalo [0, actual_num_tokens), que usaremos para o modelo no dispositivo. Porém, o dataset que lemos tem tokens do intervalo do vocabulário global muito maior, [0, WORD_VOCAB_SIZE).

Portanto, precisamos mapear os tokens globais em seus tokens locais correspondentes. Os IDs dos tokens locais são fornecidos simplesmente pelos índices ao tensor selected_tokens computado na etapa anterior.

@tf.function def map_to_local_token_ids(client_data, client_keys): global_to_local = tf.lookup.StaticHashTable( # Note int32 -> int64 maps are not supported tf.lookup.KeyValueTensorInitializer( keys=tf.cast(client_keys, dtype=TOKEN_DTYPE), # Note we need to use tf.shape, not the static # shape client_keys.shape[0] values=tf.range(0, limit=tf.shape(client_keys)[0], dtype=TOKEN_DTYPE)), # We use -1 for tokens that were not selected, which can occur for clients # with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens. # We will simply remove these invalid indices from the batch below. default_value=-1) def to_local_ids(sparse_tokens): indices_t = tf.transpose(sparse_tokens.indices) batch_indices = indices_t[0] # First column tokens = indices_t[1] # Second column tokens = tf.map_fn( lambda global_token_id: global_to_local.lookup(global_token_id), tokens) # Remove tokens that aren't actually available (looked up as -1): available_tokens = tokens >= 0 tokens = tokens[available_tokens] batch_indices = batch_indices[available_tokens] updated_indices = tf.transpose( tf.concat([[batch_indices], [tokens]], axis=0)) st = tf.sparse.SparseTensor( updated_indices, tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE), # Each client has at most MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens. dense_shape=[sparse_tokens.dense_shape[0], MAX_TOKENS_SELECTED_PER_CLIENT]) st = tf.sparse.reorder(st) return st return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))
# Simple test client_keys, actual_num_tokens = keys_for_client( batched_dataset3, MAX_TOKENS_SELECTED_PER_CLIENT) client_keys = client_keys[:actual_num_tokens] d = map_to_local_token_ids(batched_dataset3, client_keys) batch = next(iter(d)) all_tokens = tf.gather(batch.tokens.indices, indices=1, axis=1) # Confirm we have local indices in the range [0, MAX): assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT assert tf.math.reduce_max(all_tokens) >= 0

Treine o (sub)modelo local em cada cliente

Observe que federated_select retornará as fatias selecionadas como um tf.data.Dataset na mesma ordem que as chaves de seleção. Portanto, primeiro definimos uma função utilitária para receber um Dataset como esse e convertê-lo em um único tensor denso, que pode ser usado como os pesos do modelo do cliente.

@tf.function def slices_dataset_to_tensor(slices_dataset): """Convert a dataset of slices to a tensor.""" # Use batching to gather all of the slices into a single tensor. d = slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT, drop_remainder=False) iter_d = iter(d) tensor = next(iter_d) # Make sure we have consumed everything opt = iter_d.get_next_as_optional() tf.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY') return tensor
# Simple test weights = np.random.random( size=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32) model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(weights) weights2 = slices_dataset_to_tensor(model_slices_as_dataset) np.testing.assert_array_equal(weights, weights2)

Agora temos todos os componentes necessários para definir um loop de treinamento simples que será executado em cada cliente.

@tf.function def client_train_fn(model, client_optimizer, model_slices_as_dataset, client_data, client_keys, actual_num_tokens): initial_model_weights = slices_dataset_to_tensor(model_slices_as_dataset) assert len(model.trainable_variables) == 1 model.trainable_variables[0].assign(initial_model_weights) # Only keep the "real" (unpadded) keys. client_keys = client_keys[:actual_num_tokens] client_data = map_to_local_token_ids(client_data, client_keys) loss_fn = tf.keras.losses.BinaryCrossentropy() for features, labels in client_data: with tf.GradientTape() as tape: predictions = model(features) loss = loss_fn(labels, predictions) grads = tape.gradient(loss, model.trainable_variables) client_optimizer.apply_gradients(zip(grads, model.trainable_variables)) model_weights_delta = model.trainable_weights[0] - initial_model_weights model_weights_delta = tf.slice(model_weights_delta, begin=[0, 0], size=[actual_num_tokens, -1]) return client_keys, model_weights_delta
# Simple test # Note if you execute this cell a second time, you need to also re-execute # the preceeding cell to avoid "tf.function-decorated function tried to # create variables on non-first call" errors. on_device_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE) client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001) client_keys, actual_num_tokens = keys_for_client( batched_dataset2, MAX_TOKENS_SELECTED_PER_CLIENT) model_slices_as_dataset = tf.data.Dataset.from_tensor_slices( np.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE), dtype=np.float32)) keys, delta = client_train_fn( on_device_model, client_optimizer, model_slices_as_dataset, client_data=batched_dataset3, client_keys=client_keys, actual_num_tokens=actual_num_tokens) print(delta)

IndexedSlices agregados

Usamos tff.federated_aggregate para construir uma soma esparsa federada para IndexedSlices. Essa implementação simples tem a restrição de que dense_shape é conhecido antecipadamente de forma estática. Observe também que essa soma é somente semiesparsa, no sentido de que a comunicação cliente -> servidor é esparsa, mas o servidor mantém uma representação densa da soma em accumulate e merge, e gera como saída essa representação densa.

def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape): """ Sums IndexedSlices@CLIENTS to a dense @SERVER Tensor. Intermediate aggregation is performed by converting to a dense representation, which may not be suitable for all applications. Args: slice_indices: An IndexedSlices.indices tensor @CLIENTS. slice_values: An IndexedSlices.values tensor @CLIENTS. dense_shape: A statically known dense shape. Returns: A dense tensor placed @SERVER representing the sum of the client's IndexedSclies. """ slices_dtype = slice_values.type_signature.member.dtype zero = tff.tf_computation( lambda: tf.zeros(dense_shape, dtype=slices_dtype))() @tf.function def accumulate_slices(dense, client_value): indices, slices = client_value # There is no built-in way to add `IndexedSlices`, but # tf.convert_to_tensor is a quick way to convert to a dense representation # so we can add them. return dense + tf.convert_to_tensor( tf.IndexedSlices(slices, indices, dense_shape)) return tff.federated_aggregate( (slice_indices, slice_values), zero=zero, accumulate=tff.tf_computation(accumulate_slices), merge=tff.tf_computation(lambda d1, d2: tf.add(d1, d2, name='merge')), report=tff.tf_computation(lambda d: d))

Para fins de teste, construa uma federated_computation mínima:

dense_shape = (6, 2) indices_type = tff.TensorType(tf.int64, (None,)) values_type = tff.TensorType(tf.float32, (None, 2)) client_slice_type = tff.type_at_clients( (indices_type, values_type)) @tff.federated_computation(client_slice_type) def test_sum_indexed_slices(indices_values_at_client): indices, values = indices_values_at_client return federated_indexed_slices_sum(indices, values, dense_shape) print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices( values=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]], dtype=np.float32), indices=[2, 0, 1, 5], dense_shape=dense_shape) y = tf.IndexedSlices( values=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32), indices=[1, 3], dense_shape=dense_shape) # Sum one. result = test_sum_indexed_slices([(x.indices, x.values)]) np.testing.assert_array_equal(tf.convert_to_tensor(x), result) # Sum two. expected = [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]] result = test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)]) np.testing.assert_array_almost_equal(expected, result)

Juntando tudo em uma federated_computation

Agora usamos o TFF para juntar todos os componentes em uma computação federada (tff.federated_computation).

DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE) client_data_type = tff.SequenceType(batched_dataset1.element_spec) model_type = tff.TensorType(tf.float32, shape=DENSE_MODEL_SHAPE)

Usamos uma função básica de treinamento do servidor baseada no cálculo federado de médias, aplicando a atualização com uma taxa de aprendizado do servidor igual a 1,0. É importante aplicarmos uma atualização (delta) ao modelo em vez de simplesmente fazer a média dos modelos fornecidos pelos clientes, pois, caso contrário, se uma determinada fatia do modelo não tiver sido treinada por qualquer cliente em uma determinada rodada, seus coeficientes podem ser iguais a zero.

@tff.tf_computation def server_update(current_model_weights, update_sum, num_clients): average_update = update_sum / num_clients return current_model_weights + average_update

Precisamos de mais alguns componentes da tff.tf_computation:

# Function to select slices from the model weights in federated_select: select_fn = tff.tf_computation( lambda model_weights, index: tf.gather(model_weights, index)) # We need to wrap `client_train_fn` as a `tff.tf_computation`, making # sure we do any operations that might construct `tf.Variable`s outside # of the `tf.function` we are wrapping. @tff.tf_computation def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys, actual_num_tokens): # Note this is amaller than the global model, using # MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE. # We would like a model of size `actual_num_tokens`, but we # can't build the model dynamically, so we will slice off the padded # weights at the end. client_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE) client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1) return client_train_fn(client_model, client_optimizer, model_slices_as_dataset, client_data, client_keys, actual_num_tokens) @tff.tf_computation def keys_for_client_tff(client_data): return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)

Agora está tudo pronto para juntarmos todas as peças!

@tff.federated_computation( tff.type_at_server(model_type), tff.type_at_clients(client_data_type)) def sparse_model_update(server_model, client_data): max_tokens = tff.federated_value(MAX_TOKENS_SELECTED_PER_CLIENT, tff.SERVER) keys_at_clients, actual_num_tokens = tff.federated_map( keys_for_client_tff, client_data) model_slices = tff.federated_select(keys_at_clients, max_tokens, server_model, select_fn) update_keys, update_slices = tff.federated_map( client_train_fn_tff, (model_slices, client_data, keys_at_clients, actual_num_tokens)) dense_update_sum = federated_indexed_slices_sum(update_keys, update_slices, DENSE_MODEL_SHAPE) num_clients = tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS)) updated_server_model = tff.federated_map( server_update, (server_model, dense_update_sum, num_clients)) return updated_server_model print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)

Vamos treinar um modelo!

Agora que temos uma função de treinamento, vamos testá-la.

server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE) server_model.compile( # Compile to make evaluation easy. optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.0), # Unused loss=tf.keras.losses.BinaryCrossentropy(), metrics=[ tf.keras.metrics.Precision(name='precision'), tf.keras.metrics.AUC(name='auc'), tf.keras.metrics.Recall(top_k=2, name='recall_at_2'), ]) def evaluate(model, dataset, name): metrics = model.evaluate(dataset, verbose=0) metrics_str = ', '.join([f'{k}={v:.2f}' for k, v in (zip(server_model.metrics_names, metrics))]) print(f'{name}: {metrics_str}')
print('Before training') evaluate(server_model, batched_dataset1, 'Client 1') evaluate(server_model, batched_dataset2, 'Client 2') evaluate(server_model, batched_dataset3, 'Client 3') model_weights = server_model.trainable_weights[0] client_datasets = [batched_dataset1, batched_dataset2, batched_dataset3] for _ in range(10): # Run 10 rounds of FedAvg # We train on 1, 2, or 3 clients per round, selecting # randomly. cohort_size = np.random.randint(1, 4) clients = np.random.choice([0, 1, 2], cohort_size, replace=False) print('Training on clients', clients) model_weights = sparse_model_update( model_weights, [client_datasets[i] for i in clients]) server_model.set_weights([model_weights]) print('After training') evaluate(server_model, batched_dataset1, 'Client 1') evaluate(server_model, batched_dataset2, 'Client 2') evaluate(server_model, batched_dataset3, 'Client 3')
Before training Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60 Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50 Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40 Training on clients [0 1] Training on clients [0 2 1] Training on clients [2 0] Training on clients [1 0 2] Training on clients [2] Training on clients [2 0] Training on clients [1 2 0] Training on clients [0] Training on clients [2] Training on clients [1 2] After training Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80 Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00 Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80