Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/es-419/federated/tutorials/sparse_federated_learning.ipynb
25118 views
Kernel: Python 3
#@title Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.

Aprendizaje federado de modelos grandes eficiente para el cliente a través de federated_select y agregación dispersa

Este tutorial muestra cómo se puede usar TFF para entrenar un modelo muy grande donde cada dispositivo cliente solo descarga y actualiza una pequeña parte del modelo, con ayuda de tff.federated_select y agregación dispersa. Si bien este tutorial es bastante completo, el tutorial tff.federated_select y el tutorial de algoritmos de FL personalizados sirven de introducción a algunas de las técnicas que se utilizan aquí.

Concretamente, en este tutorial consideramos la regresión logística para la clasificación de etiquetas múltiples, prediciendo qué "etiquetas" están asociadas con una cadena de texto basada en una representación de características de bolsa de palabras. Es importante destacar que los costos de comunicación y cálculo del lado del cliente están controlados por una constante fija (MAX_TOKENS_SELECTED_PER_CLIENT) y no escalan con el tamaño general del vocabulario, que podría ser extremadamente grande en la práctica.

#@test {"skip": true} !pip install --quiet --upgrade tensorflow-federated
import collections from collections.abc import Callable import itertools import numpy as np import tensorflow as tf import tensorflow_federated as tff

Cada cliente usará federated_select para seleccionar las filas de las ponderaciones del modelo para, como máximo, esta cantidad de tokens únicos. Esto limita el tamaño del modelo local del cliente y la cantidad de comunicación servidor -> cliente (federated_select) y cliente - > servidor (federated_aggregate).

Este tutorial debería ejecutarse correctamente incluso si se ajusta a un valor tan pequeño como 1 (lo que garantiza que no se seleccionen todas las fichas de cada cliente) o a un valor grande, aunque la convergencia del modelo puede verse afectada.

MAX_TOKENS_SELECTED_PER_CLIENT = 6

También definimos algunas constantes para varios tipos. Para esta colaboración, un token es un identificador entero para una palabra en particular después de analizar el conjunto de datos.

# There are some constraints on types # here that will require some explicit type conversions: # - `tff.federated_select` requires int32 # - `tf.SparseTensor` requires int64 indices. TOKEN_DTYPE = tf.int64 SELECT_KEY_DTYPE = tf.int32 # Type for counts of token occurences. TOKEN_COUNT_DTYPE = tf.int32 # A sparse feature vector can be thought of as a map # from TOKEN_DTYPE to FEATURE_DTYPE. # Our features are {0, 1} indicators, so we could potentially # use tf.int8 as an optimization. FEATURE_DTYPE = tf.int32

Configuración del problema: conjunto de datos y modelo

En este tutorial construimos un pequeño conjunto de datos de juguete para facilitar la experimentación. Sin embargo, el formato del conjunto de datos es compatible con Federated StackOverflow y la arquitectura del modelo y el preprocesamiento se adoptan del problema de predicción de etiquetas StackOverflow de Adaptive Federated Optimization.

Parseo y preprocesamiento de conjuntos de datos

NUM_OOV_BUCKETS = 1 BatchType = collections.namedtuple('BatchType', ['tokens', 'tags']) def build_to_ids_fn(word_vocab: list[str], tag_vocab: list[str]) -> Callable[[tf.Tensor], tf.Tensor]: """Constructs a function mapping examples to sequences of token indices.""" word_table_values = np.arange(len(word_vocab), dtype=np.int64) word_table = tf.lookup.StaticVocabularyTable( tf.lookup.KeyValueTensorInitializer(word_vocab, word_table_values), num_oov_buckets=NUM_OOV_BUCKETS) tag_table_values = np.arange(len(tag_vocab), dtype=np.int64) tag_table = tf.lookup.StaticVocabularyTable( tf.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values), num_oov_buckets=NUM_OOV_BUCKETS) def to_ids(example): """Converts a Stack Overflow example to a bag-of-words/tags format.""" sentence = tf.strings.join([example['tokens'], example['title']], separator=' ') # We represent that label (output tags) densely. raw_tags = example['tags'] tags = tf.strings.split(raw_tags, sep='|') tags = tag_table.lookup(tags) tags, _ = tf.unique(tags) tags = tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS) tags = tf.reduce_max(tags, axis=0) # We represent the features as a SparseTensor of {0, 1}s. words = tf.strings.split(sentence) tokens = word_table.lookup(words) tokens, _ = tf.unique(tokens) # Note: We could choose to use the word counts as the feature vector # instead of just {0, 1} values (see tf.unique_with_counts). tokens = tf.reshape(tokens, shape=(tf.size(tokens), 1)) tokens_st = tf.SparseTensor( tokens, tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE), dense_shape=(len(word_vocab) + NUM_OOV_BUCKETS,)) tokens_st = tf.sparse.reorder(tokens_st) return BatchType(tokens_st, tags) return to_ids
def build_preprocess_fn(word_vocab, tag_vocab): @tf.function def preprocess_fn(dataset): to_ids = build_to_ids_fn(word_vocab, tag_vocab) # We *don't* shuffle in order to make this colab deterministic for # easier testing and reproducibility. # But real-world training should use `.shuffle()`. return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE) return preprocess_fn

Un pequeño conjunto de datos de juguete

Construimos un pequeño conjunto de datos de juguete con un vocabulario global de 12 palabras y 3 clientes. Este pequeño ejemplo sirve para probar casos extremos (por ejemplo, tenemos dos clientes con menos de MAX_TOKENS_SELECTED_PER_CLIENT = 6 tokens distintos y uno con más) y desarrollar el código.

No obstante, los casos de uso de este enfoque en el mundo real serían vocabularios globales de decenas de millones o más, con quizás miles de tokens distintos en cada cliente. Como el formato de los datos es el mismo, la extensión a problemas de banco de pruebas más realistas, por ejemplo, el conjunto de datos tff.simulation.datasets.stackoverflow.load_data(), debería ser sencilla.

En primer lugar, definimos nuestros vocabularios de palabras y etiquetas.

# Features FRUIT_WORDS = ['apple', 'orange', 'pear', 'kiwi'] VEGETABLE_WORDS = ['carrot', 'broccoli', 'arugula', 'peas'] FISH_WORDS = ['trout', 'tuna', 'cod', 'salmon'] WORD_VOCAB = FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS # Labels TAG_VOCAB = ['FRUIT', 'VEGETABLE', 'FISH']

Ahora, creamos 3 clientes con pequeños conjuntos de datos locales. Si está ejecutando este tutorial en Colab, quizá le resulte útil la característica "reflejar celda en pestaña" para fijar esta celda y su salida con el fin de interpretar/comprobar la salida de las funciones que se desarrollan a continuación.

preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB) def make_dataset(raw): d = tf.data.Dataset.from_tensor_slices( # Matches the StackOverflow formatting collections.OrderedDict( tokens=tf.constant([t[0] for t in raw]), tags=tf.constant([t[1] for t in raw]), title=['' for _ in raw])) d = preprocess_fn(d) return d # 4 distinct tokens CLIENT1_DATASET = make_dataset([ ('apple orange apple orange', 'FRUIT'), ('carrot trout', 'VEGETABLE|FISH'), ('orange apple', 'FRUIT'), ('orange', 'ORANGE|CITRUS') # 2 OOV tag ]) # 6 distinct tokens CLIENT2_DATASET = make_dataset([ ('pear cod', 'FRUIT|FISH'), ('arugula peas', 'VEGETABLE'), ('kiwi pear', 'FRUIT'), ('sturgeon', 'FISH'), # OOV word ('sturgeon bass', 'FISH') # 2 OOV words ]) # A client with all possible words & tags (13 distinct tokens). # With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model # slices for all tokens that occur on this client. CLIENT3_DATASET = make_dataset([ (' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)), # Mathe the OOV token and 'salmon' occur in the largest number # of examples on this client: ('salmon oovword', 'FISH|OOVTAG') ]) print('Word vocab') for i, word in enumerate(WORD_VOCAB): print(f'{i:2d} {word}') print('\nTag vocab') for i, tag in enumerate(TAG_VOCAB): print(f'{i:2d} {tag}')
Word vocab 0 apple 1 orange 2 pear 3 kiwi 4 carrot 5 broccoli 6 arugula 7 peas 8 trout 9 tuna 10 cod 11 salmon Tag vocab 0 FRUIT 1 VEGETABLE 2 FISH

Defina constantes para los números brutos de características de entrada (tokens/palabras) y etiquetas (etiquetas de publicación). Nuestros espacios de entrada/salida reales son NUM_OOV_BUCKETS = 1 más grandes porque agregamos un token/etiqueta OOV.

NUM_WORDS = len(WORD_VOCAB) NUM_TAGS = len(TAG_VOCAB) WORD_VOCAB_SIZE = NUM_WORDS + NUM_OOV_BUCKETS TAG_VOCAB_SIZE = NUM_TAGS + NUM_OOV_BUCKETS

Cree versiones por lotes de los conjuntos de datos y lotes individuales, que serán útiles para probar el código a medida que avanzamos.

batched_dataset1 = CLIENT1_DATASET.batch(2) batched_dataset2 = CLIENT2_DATASET.batch(3) batched_dataset3 = CLIENT3_DATASET.batch(2) batch1 = next(iter(batched_dataset1)) batch2 = next(iter(batched_dataset2)) batch3 = next(iter(batched_dataset3))

Definición de un modelo con entradas dispersas

Utilizamos un modelo de regresión logística independiente y simple para cada etiqueta.

def create_logistic_model(word_vocab_size: int, vocab_tags_size: int): model = tf.keras.models.Sequential([ tf.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True), tf.keras.layers.Dense( vocab_tags_size, activation='sigmoid', kernel_initializer=tf.keras.initializers.zeros, # For simplicity, don't use a bias vector; this means the model # is a single tensor, and we only need sparse aggregation of # the per-token slices of the model. Generalizing to also handle # other model weights that are fully updated # (non-dense broadcast and aggregate) would be a good exercise. use_bias=False), ]) return model

Hagamos predicciones para asegurarnos de que funciona:

model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE) p = model.predict(batch1.tokens) print(p)
[[0.5 0.5 0.5 0.5] [0.5 0.5 0.5 0.5]]

Y algún entrenamiento centralizado simple:

model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001), loss=tf.keras.losses.BinaryCrossentropy()) model.train_on_batch(batch1.tokens, batch1.tags)

Bloques de creación para el cálculo federado

Vamos a implementar una versión simple del algoritmo de promediado federado con la diferencia clave de que cada dispositivo solo descarga un subconjunto relevante del modelo y solo aporta actualizaciones a ese subconjunto.

Usamos M como abreviatura de MAX_TOKENS_SELECTED_PER_CLIENT. A un nivel alto, una ronda de entrenamiento implica estos pasos:

  1. Cada cliente participante escanea su conjunto de datos local, a la vez que parsea las cadenas de entrada y las asigna a los tokens correctos (índices int). Esto requiere acceso al (enorme) diccionario global (esto podría evitarse si se usan técnicas de hash de características). A continuación, contamos de forma dispersa cuántas veces aparece cada token. Si aparecen U tokens únicos en el dispositivo, elegimos los tokens num_actual_tokens = min(U, M) más frecuentes para entrenar.

  2. Los clientes usan federated_select para recuperar los coeficientes del modelo para los tokens num_actual_tokens seleccionados del servidor. Cada segmento del modelo es un tensor de forma (TAG_VOCAB_SIZE, ), por lo que los datos totales que se transmiten al cliente tienen como máximo el tamaño TAG_VOCAB_SIZE * M (consulte la nota a continuación).

  3. Los clientes construyen una asignación global_token -> local_token donde el token local (índice int) es el índice del token global en la lista de tokens seleccionados.

  4. Los clientes usan una versión "pequeña" del modelo global que solo tiene coeficientes para como máximo M tokens, del rango [0, num_actual_tokens) . La asignación global -> local sirve para inicializar los parámetros densos de este modelo a partir de los sectores seleccionados del modelo.

  5. Los clientes entrenan su modelo local con SGD en datos preprocesados ​​con la asignación global -> local.

  6. Los clientes convierten los parámetros de su modelo local en actualizaciones IndexedSlices con ayuda de la asignación local -> global para indexar las filas. El servidor agrega estas actualizaciones mediante una agregación de suma dispersa.

  7. El servidor toma el resultado (denso) de la agregación anterior, lo divide por la cantidad de clientes participantes y aplica la actualización promedio resultante al modelo global.

En esta sección, construimos los bloques de creación de estos pasos, que luego se combinarán en un federated_computation final que captura la lógica completa de una ronda de entrenamiento.

NOTA: La descripción anterior oculta un detalle técnico: tanto federated_select como la construcción del modelo local requieren formas estáticamente conocidas, por lo que no podemos usar el tamaño dinámico de num_actual_tokens por cliente. En vez de eso, usamos el valor estático M y agregamos amortiguado cuando sea necesario. Esto no afecta la semántica del algoritmo.

Cuente los tokens cliente y decida qué modelo se divide en federated_select

Cada dispositivo debe decidir qué "porciones" del modelo son relevantes para su conjunto de datos de entrenamiento local. Para nuestro problema, hacemos esto contando (¡de forma dispersa!) cuántos ejemplos contiene cada token en el conjunto de datos de entrenamiento del cliente.

@tf.function def token_count_fn(token_counts, batch): """Adds counts from `batch` to the running `token_counts` sum.""" # Sum across the batch dimension. flat_tokens = tf.sparse.reduce_sum( batch.tokens, axis=0, output_is_sparse=True) flat_tokens = tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE) return tf.sparse.add(token_counts, flat_tokens)
# Simple tests # Create the initial zero token counts using empty tensors. initial_token_counts = tf.SparseTensor( indices=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE), values=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE), dense_shape=(WORD_VOCAB_SIZE,)) client_token_counts = batched_dataset1.reduce(initial_token_counts, token_count_fn) tokens = tf.reshape(client_token_counts.indices, (-1,)).numpy() print('tokens:', tokens) np.testing.assert_array_equal(tokens, [0, 1, 4, 8]) # The count is the number of *examples* in which the token/word # occurs, not the total number of occurences, since we still featurize # multiple occurences in the same example as a "1". counts = client_token_counts.values.numpy() print('counts:', counts) np.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8] counts: [2 3 1 1]

Seleccionaremos los parámetros del modelo correspondientes a los tokens MAX_TOKENS_SELECTED_PER_CLIENT que aparecen con más frecuencia en el dispositivo. Si hay menos tokens de esta cantidad en el dispositivo, rellenamos la lista para permitir el uso de federated_select.

Tenga en cuenta que es posible que otras estrategias sean mejores, por ejemplo, seleccionar tokens aleatoriamente (quizás en función de su probabilidad de aparición). Esto garantizaría que todos los sectores del modelo (para los cuales el cliente tiene datos) tengan alguna posibilidad de actualizarse.

@tf.function def keys_for_client(client_dataset, max_tokens_per_client): """Computes a set of max_tokens_per_client keys.""" initial_token_counts = tf.SparseTensor( indices=tf.zeros((0, 1), dtype=TOKEN_DTYPE), values=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE), dense_shape=(WORD_VOCAB_SIZE,)) client_token_counts = client_dataset.reduce(initial_token_counts, token_count_fn) # Find the most-frequently occuring tokens tokens = tf.reshape(client_token_counts.indices, shape=(-1,)) counts = client_token_counts.values perm = tf.argsort(counts, direction='DESCENDING') tokens = tf.gather(tokens, perm) counts = tf.gather(counts, perm) num_raw_tokens = tf.shape(tokens)[0] actual_num_tokens = tf.minimum(max_tokens_per_client, num_raw_tokens) selected_tokens = tokens[:actual_num_tokens] paddings = [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]] padded_tokens = tf.pad(selected_tokens, paddings=paddings) # Make sure the type is statically determined padded_tokens = tf.reshape(padded_tokens, shape=(max_tokens_per_client,)) # We will pass these tokens as keys into `federated_select`, which # requires SELECT_KEY_DTYPE=tf.int32 keys. padded_tokens = tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE) return padded_tokens, actual_num_tokens
# Simple test # Case 1: actual_num_tokens > max_tokens_per_client selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 3) assert tf.size(selected_tokens) == 3 assert actual_num_tokens == 3 # Case 2: actual_num_tokens < max_tokens_per_client selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 10) assert tf.size(selected_tokens) == 10 assert actual_num_tokens == 4

Asignación de tokens globales a tokens locales

La selección anterior nos brinda un conjunto denso de tokens en el rango [0, actual_num_tokens) que usaremos para el modelo en el dispositivo. Sin embargo, el conjunto de datos que leemos tiene tokens de un rango de vocabulario global mucho más amplio [0, WORD_VOCAB_SIZE).

Por lo tanto, necesitamos asignar los tokens globales a sus tokens locales correspondientes. Los identificadores de tokens locales simplemente vienen dados por los índices en el tensor selected_tokens calculado en el paso anterior.

@tf.function def map_to_local_token_ids(client_data, client_keys): global_to_local = tf.lookup.StaticHashTable( # Note int32 -> int64 maps are not supported tf.lookup.KeyValueTensorInitializer( keys=tf.cast(client_keys, dtype=TOKEN_DTYPE), # Note we need to use tf.shape, not the static # shape client_keys.shape[0] values=tf.range(0, limit=tf.shape(client_keys)[0], dtype=TOKEN_DTYPE)), # We use -1 for tokens that were not selected, which can occur for clients # with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens. # We will simply remove these invalid indices from the batch below. default_value=-1) def to_local_ids(sparse_tokens): indices_t = tf.transpose(sparse_tokens.indices) batch_indices = indices_t[0] # First column tokens = indices_t[1] # Second column tokens = tf.map_fn( lambda global_token_id: global_to_local.lookup(global_token_id), tokens) # Remove tokens that aren't actually available (looked up as -1): available_tokens = tokens >= 0 tokens = tokens[available_tokens] batch_indices = batch_indices[available_tokens] updated_indices = tf.transpose( tf.concat([[batch_indices], [tokens]], axis=0)) st = tf.sparse.SparseTensor( updated_indices, tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE), # Each client has at most MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens. dense_shape=[sparse_tokens.dense_shape[0], MAX_TOKENS_SELECTED_PER_CLIENT]) st = tf.sparse.reorder(st) return st return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))
# Simple test client_keys, actual_num_tokens = keys_for_client( batched_dataset3, MAX_TOKENS_SELECTED_PER_CLIENT) client_keys = client_keys[:actual_num_tokens] d = map_to_local_token_ids(batched_dataset3, client_keys) batch = next(iter(d)) all_tokens = tf.gather(batch.tokens.indices, indices=1, axis=1) # Confirm we have local indices in the range [0, MAX): assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT assert tf.math.reduce_max(all_tokens) >= 0

Entrenamiento del (sub)modelo local en cada cliente

Observe que federated_select devolverá los sectores seleccionados como tf.data.Dataset en el mismo orden que las claves de selección. Entonces, primero definimos una función de utilidad para tomar dicho conjunto de datos y convertirlo en un único tensor denso que pueda usarse como ponderaciones del modelo del cliente.

@tf.function def slices_dataset_to_tensor(slices_dataset): """Convert a dataset of slices to a tensor.""" # Use batching to gather all of the slices into a single tensor. d = slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT, drop_remainder=False) iter_d = iter(d) tensor = next(iter_d) # Make sure we have consumed everything opt = iter_d.get_next_as_optional() tf.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY') return tensor
# Simple test weights = np.random.random( size=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32) model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(weights) weights2 = slices_dataset_to_tensor(model_slices_as_dataset) np.testing.assert_array_equal(weights, weights2)

Ahora tenemos todos los componentes que necesitamos para definir un ciclo de entrenamiento local simple que se ejecutará en cada cliente.

@tf.function def client_train_fn(model, client_optimizer, model_slices_as_dataset, client_data, client_keys, actual_num_tokens): initial_model_weights = slices_dataset_to_tensor(model_slices_as_dataset) assert len(model.trainable_variables) == 1 model.trainable_variables[0].assign(initial_model_weights) # Only keep the "real" (unpadded) keys. client_keys = client_keys[:actual_num_tokens] client_data = map_to_local_token_ids(client_data, client_keys) loss_fn = tf.keras.losses.BinaryCrossentropy() for features, labels in client_data: with tf.GradientTape() as tape: predictions = model(features) loss = loss_fn(labels, predictions) grads = tape.gradient(loss, model.trainable_variables) client_optimizer.apply_gradients(zip(grads, model.trainable_variables)) model_weights_delta = model.trainable_weights[0] - initial_model_weights model_weights_delta = tf.slice(model_weights_delta, begin=[0, 0], size=[actual_num_tokens, -1]) return client_keys, model_weights_delta
# Simple test # Note if you execute this cell a second time, you need to also re-execute # the preceeding cell to avoid "tf.function-decorated function tried to # create variables on non-first call" errors. on_device_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE) client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001) client_keys, actual_num_tokens = keys_for_client( batched_dataset2, MAX_TOKENS_SELECTED_PER_CLIENT) model_slices_as_dataset = tf.data.Dataset.from_tensor_slices( np.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE), dtype=np.float32)) keys, delta = client_train_fn( on_device_model, client_optimizer, model_slices_as_dataset, client_data=batched_dataset3, client_keys=client_keys, actual_num_tokens=actual_num_tokens) print(delta)

Agregación de IndexedSlices

Usamos tff.federated_aggregate para construir una suma dispersa federada para IndexedSlices. Esta implementación simple tiene la restricción de que la dense_shape se conoce estáticamente de antemano. Tenga en cuenta también que esta suma es solo semidispersa, en el sentido de que la comunicación cliente -> servidor es dispersa, pero el servidor mantiene una representación densa de la suma en accumulate y merge, y genera esta representación densa.

def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape): """ Sums IndexedSlices@CLIENTS to a dense @SERVER Tensor. Intermediate aggregation is performed by converting to a dense representation, which may not be suitable for all applications. Args: slice_indices: An IndexedSlices.indices tensor @CLIENTS. slice_values: An IndexedSlices.values tensor @CLIENTS. dense_shape: A statically known dense shape. Returns: A dense tensor placed @SERVER representing the sum of the client's IndexedSclies. """ slices_dtype = slice_values.type_signature.member.dtype zero = tff.tf_computation( lambda: tf.zeros(dense_shape, dtype=slices_dtype))() @tf.function def accumulate_slices(dense, client_value): indices, slices = client_value # There is no built-in way to add `IndexedSlices`, but # tf.convert_to_tensor is a quick way to convert to a dense representation # so we can add them. return dense + tf.convert_to_tensor( tf.IndexedSlices(slices, indices, dense_shape)) return tff.federated_aggregate( (slice_indices, slice_values), zero=zero, accumulate=tff.tf_computation(accumulate_slices), merge=tff.tf_computation(lambda d1, d2: tf.add(d1, d2, name='merge')), report=tff.tf_computation(lambda d: d))

Construya un federated_computation mínimo como prueba

dense_shape = (6, 2) indices_type = tff.TensorType(tf.int64, (None,)) values_type = tff.TensorType(tf.float32, (None, 2)) client_slice_type = tff.type_at_clients( (indices_type, values_type)) @tff.federated_computation(client_slice_type) def test_sum_indexed_slices(indices_values_at_client): indices, values = indices_values_at_client return federated_indexed_slices_sum(indices, values, dense_shape) print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices( values=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]], dtype=np.float32), indices=[2, 0, 1, 5], dense_shape=dense_shape) y = tf.IndexedSlices( values=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32), indices=[1, 3], dense_shape=dense_shape) # Sum one. result = test_sum_indexed_slices([(x.indices, x.values)]) np.testing.assert_array_equal(tf.convert_to_tensor(x), result) # Sum two. expected = [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]] result = test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)]) np.testing.assert_array_almost_equal(expected, result)

Todo se combina en federated_computation

Ahora usamos TFF para unir los componentes en un tff.federated_computation.

DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE) client_data_type = tff.SequenceType(batched_dataset1.element_spec) model_type = tff.TensorType(tf.float32, shape=DENSE_MODEL_SHAPE)

Usamos una función básica de entrenamiento del servidor basada en el promediado federado, aplicando la actualización con una tasa de aprendizaje del servidor de 1,0. Es importante que apliquemos una actualización (delta) al modelo, en lugar de simplemente promediar los modelos que proporciona el cliente, ya que, de lo contrario, si ningún cliente entrena una porción determinada del modelo en una ronda determinada, sus coeficientes podrían reducirse a cero.

@tff.tf_computation def server_update(current_model_weights, update_sum, num_clients): average_update = update_sum / num_clients return current_model_weights + average_update

Necesitamos un par de componentes tff.tf_computation más:

# Function to select slices from the model weights in federated_select: select_fn = tff.tf_computation( lambda model_weights, index: tf.gather(model_weights, index)) # We need to wrap `client_train_fn` as a `tff.tf_computation`, making # sure we do any operations that might construct `tf.Variable`s outside # of the `tf.function` we are wrapping. @tff.tf_computation def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys, actual_num_tokens): # Note this is amaller than the global model, using # MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE. # We would like a model of size `actual_num_tokens`, but we # can't build the model dynamically, so we will slice off the padded # weights at the end. client_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE) client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1) return client_train_fn(client_model, client_optimizer, model_slices_as_dataset, client_data, client_keys, actual_num_tokens) @tff.tf_computation def keys_for_client_tff(client_data): return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)

¡Ya estamos listos para armar todo!

@tff.federated_computation( tff.type_at_server(model_type), tff.type_at_clients(client_data_type)) def sparse_model_update(server_model, client_data): max_tokens = tff.federated_value(MAX_TOKENS_SELECTED_PER_CLIENT, tff.SERVER) keys_at_clients, actual_num_tokens = tff.federated_map( keys_for_client_tff, client_data) model_slices = tff.federated_select(keys_at_clients, max_tokens, server_model, select_fn) update_keys, update_slices = tff.federated_map( client_train_fn_tff, (model_slices, client_data, keys_at_clients, actual_num_tokens)) dense_update_sum = federated_indexed_slices_sum(update_keys, update_slices, DENSE_MODEL_SHAPE) num_clients = tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS)) updated_server_model = tff.federated_map( server_update, (server_model, dense_update_sum, num_clients)) return updated_server_model print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)

¡Entrenemos un modelo!

Ahora que tenemos nuestra función de entrenamiento, probémosla.

server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE) server_model.compile( # Compile to make evaluation easy. optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.0), # Unused loss=tf.keras.losses.BinaryCrossentropy(), metrics=[ tf.keras.metrics.Precision(name='precision'), tf.keras.metrics.AUC(name='auc'), tf.keras.metrics.Recall(top_k=2, name='recall_at_2'), ]) def evaluate(model, dataset, name): metrics = model.evaluate(dataset, verbose=0) metrics_str = ', '.join([f'{k}={v:.2f}' for k, v in (zip(server_model.metrics_names, metrics))]) print(f'{name}: {metrics_str}')
print('Before training') evaluate(server_model, batched_dataset1, 'Client 1') evaluate(server_model, batched_dataset2, 'Client 2') evaluate(server_model, batched_dataset3, 'Client 3') model_weights = server_model.trainable_weights[0] client_datasets = [batched_dataset1, batched_dataset2, batched_dataset3] for _ in range(10): # Run 10 rounds of FedAvg # We train on 1, 2, or 3 clients per round, selecting # randomly. cohort_size = np.random.randint(1, 4) clients = np.random.choice([0, 1, 2], cohort_size, replace=False) print('Training on clients', clients) model_weights = sparse_model_update( model_weights, [client_datasets[i] for i in clients]) server_model.set_weights([model_weights]) print('After training') evaluate(server_model, batched_dataset1, 'Client 1') evaluate(server_model, batched_dataset2, 'Client 2') evaluate(server_model, batched_dataset3, 'Client 3')
Before training Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60 Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50 Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40 Training on clients [0 1] Training on clients [0 2 1] Training on clients [2 0] Training on clients [1 0 2] Training on clients [2] Training on clients [2 0] Training on clients [1 2 0] Training on clients [0] Training on clients [2] Training on clients [1 2] After training Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80 Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00 Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80