Copyright 2021 The TensorFlow Federated Authors.
Client-efficient large-model federated learning via federated_select and sparse aggregation
This tutorial shows how TFF can be used to train a very large model where each client device only downloads and updates a small part of the model, using tff.federated_select and sparse aggregation. While this tutorial is fairly self-contained, the tff.federated_select tutorial and custom FL algorithms tutorial provide good introductions to some of the techniques used here.
Concretely, in this tutorial we consider logistic regression for multi-label classification, predicting which "tags" are associated with a text string based on a bag-of-words feature representation. Importantly, communication and client-side computation costs are controlled by a fixed constant (MAX_TOKENS_SELECTED_PER_CLIENT), and do not scale with the overall vocabulary size, which could be extremely large in practical settings.
Each client will federated_select the rows of the model weights for at most this many unique tokens. This upper-bounds the size of the client's local model and the amount of server -> client (federated_select) and client -> server (federated_aggregate) communication performed.
This tutorial should still run correctly even if you set this constant as small as 1 (ensuring that not all tokens from each client are selected) or to a large value, though model convergence may be affected.
We also define a few constants for various types. For this colab, a token is an integer identifier for a particular word after parsing the dataset.
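For concreteness, here is a minimal sketch of what these constants might look like; the type-alias names are illustrative placeholders, not prescribed by TFF:

```python
import tensorflow as tf

# The key constant discussed above: each client selects at most this many
# unique tokens, which bounds the local model size and the communication.
MAX_TOKENS_SELECTED_PER_CLIENT = 6

# Hypothetical type aliases: a token is an integer word id after parsing.
TOKEN_DTYPE = tf.int64       # dtype for token ids and token counts
SELECT_KEY_DTYPE = tf.int32  # dtype for federated_select keys
```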
Setting up the problem: Dataset and Model
We construct a tiny toy dataset for easy experimentation in this tutorial. However, the format of the dataset is compatible with Federated StackOverflow, and the pre-processing and model architecture are adapted from the StackOverflow tag prediction problem of Adaptive Federated Optimization.
Dataset parsing and pre-processing
A tiny toy dataset
We construct a tiny toy dataset with a global vocabulary of 12 words and 3 clients. This tiny example is useful for testing edge cases (for example, we have two clients with fewer than MAX_TOKENS_SELECTED_PER_CLIENT = 6 distinct tokens, and one with more) and for developing the code.
However, real-world use cases of this approach involve global vocabularies of tens of millions of words or more, with perhaps thousands of distinct tokens appearing on each client. Because the format of the data is the same, the extension to more realistic testbed problems, e.g. the tff.simulation.datasets.stackoverflow.load_data() dataset, should be straightforward.
First, we define our word and tag vocabularies.
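For illustration, here is a vocabulary sketch of the right shape; the specific words and tags below are placeholders, not necessarily the tutorial's actual lists:

```python
# 12 words in the global vocabulary, plus a small tag vocabulary.
WORD_VOCAB = [
    'ate', 'cat', 'dog', 'game', 'hamburger', 'hot', 'jumped', 'moon',
    'over', 'pizza', 'the', 'watched',
]
TAG_VOCAB = ['animal', 'food', 'nature', 'sports']
```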
Now, we create 3 clients with small local datasets. If you are running this tutorial in colab, it may be useful to use the "mirror cell in tab" feature to pin this cell and its output in order to interpret/check the output of the functions developed below.
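Here is a sketch of what such client data could look like, assuming the StackOverflow-compatible format of a space-separated tokens string plus a comma-separated tags string (all example text is made up; note that two clients have fewer than 6 distinct tokens and one has more):

```python
import tensorflow as tf

CLIENT_DATA = {
    'client_0': [('the cat ate the hamburger', 'animal,food'),
                 ('the dog watched the moon', 'animal,nature'),
                 ('the hot pizza game', 'food,sports')],      # 10 distinct tokens
    'client_1': [('hot pizza', 'food')],                      # 2 distinct tokens
    'client_2': [('the dog jumped over the cat', 'animal')],  # 5 distinct tokens
}

def make_client_dataset(examples):
  tokens, tags = zip(*examples)
  return tf.data.Dataset.from_tensor_slices(
      {'tokens': list(tokens), 'tags': list(tags)})
```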
Define constants for the raw numbers of input features (tokens/words) and labels (post tags). Our actual input/output spaces are NUM_OOV_BUCKETS = 1 larger because we add an OOV token / tag.
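Under the placeholder vocabularies sketched above, these constants work out as follows:

```python
NUM_OOV_BUCKETS = 1

# One larger than the raw vocabularies, to make room for the OOV bucket.
WORD_VOCAB_SIZE = len(WORD_VOCAB) + NUM_OOV_BUCKETS  # 13 in the sketch above
TAG_VOCAB_SIZE = len(TAG_VOCAB) + NUM_OOV_BUCKETS    # 5 in the sketch above
```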
Create batched versions of the datasets, and individual batches, which will be useful in testing code as we go.
Define a model with sparse inputs
We use a simple independent logistic regression model for each tag.
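A minimal sketch of such a model: a single dense layer mapping bag-of-words counts to a per-tag sigmoid output (omitting the bias term is an assumption made here for simplicity):

```python
import tensorflow as tf

def create_logistic_model(word_vocab_size, tag_vocab_size):
  # One independent logistic regression per tag: each output unit sees the
  # full bag-of-words input and produces an independent probability.
  return tf.keras.Sequential([
      tf.keras.layers.Dense(
          tag_vocab_size, input_shape=(word_vocab_size,),
          activation='sigmoid', use_bias=False),
  ])
```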
Let's make sure it works, first by making predictions:
And some simple centralized training:
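As a sketch, centralized training could look like the following, where batched_dataset is assumed to yield (bag-of-words, multi-hot tags) pairs from the batching step above:

```python
model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
    loss=tf.keras.losses.BinaryCrossentropy())
model.fit(batched_dataset, epochs=5)
```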
Building blocks for the federated computation
We will implement a simple version of the Federated Averaging algorithm with the key difference that each device only downloads a relevant subset of the model, and only contributes updates to that subset.
We use M as shorthand for MAX_TOKENS_SELECTED_PER_CLIENT. At a high level, one round of training involves these steps:
1. Each participating client scans over its local dataset, parsing the input strings and mapping them to the correct tokens (int indexes). This requires access to the global (large) dictionary (this could potentially be avoided using feature hashing techniques). We then sparsely count how many times each token occurs. If U unique tokens occur on device, we choose the num_actual_tokens = min(U, M) most frequent tokens to train.
2. The clients use federated_select to retrieve the model coefficients for the num_actual_tokens selected tokens from the server. Each model slice is a tensor of shape (TAG_VOCAB_SIZE, ), so the total data transmitted to the client is at most of size TAG_VOCAB_SIZE * M (see note below).
3. The clients construct a mapping global_token -> local_token where the local token (int index) is the index of the global token in the list of selected tokens.
4. The clients use a "small" version of the global model that only has coefficients for at most M tokens, from the range [0, num_actual_tokens). The global -> local mapping is used to initialize the dense parameters of this model from the selected model slices.
5. Clients train their local model using SGD on data preprocessed with the global -> local mapping.
6. Clients turn the parameters of their local model into IndexedSlices updates (illustrated just after this list) using the local -> global mapping to index the rows. The server aggregates these updates using a sparse sum aggregation.
7. The server takes the (dense) result of the above aggregation, divides it by the number of clients participating, and applies the resulting average update to the global model.
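To make step 6 concrete, here is a small illustration of the sparse update format, using the placeholder sizes WORD_VOCAB_SIZE = 13 and TAG_VOCAB_SIZE = 5 from the sketches above:

```python
import tensorflow as tf

# A client that selected global tokens [2, 7] sends only those two rows of
# its update; all other rows of the (13, 5) dense update are implicitly zero.
client_update = tf.IndexedSlices(
    values=tf.constant([[0.1, -0.2, 0.0, 0.05, 0.0],
                        [0.0, 0.3, -0.1, 0.0, 0.0]]),  # [2, TAG_VOCAB_SIZE]
    indices=tf.constant([2, 7], dtype=tf.int64),       # selected global tokens
    dense_shape=tf.constant([13, 5], dtype=tf.int64))  # [WORD_VOCAB_SIZE, TAG_VOCAB_SIZE]
```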
In this section we construct the building blocks for these steps, which will then be combined in a final federated_computation that captures the full logic of one training round.
NOTE: The above description hides one technical detail: both federated_select and the construction of the local model require statically known shapes, and so we cannot use the dynamic per-client num_actual_tokens size. Instead, we use the static value M, adding padding where needed. This does not affect the semantics of the algorithm.
Count client tokens and decide which model slices to federated_select
Each device needs to decide which "slices" of the model are relevant to its local training dataset. For our problem, we do this by (sparsely!) counting how many examples contain each token in the client's training dataset.
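An eager-mode sketch of this counting, where parse_fn is a hypothetical helper that maps a raw example to a 1-D tensor of global token ids:

```python
import collections

def count_client_tokens(raw_dataset, parse_fn):
  # Count, for each global token id, how many examples contain it; only
  # tokens that actually occur on this client take up any space.
  counts = collections.Counter()
  for example in raw_dataset:
    counts.update(set(parse_fn(example).numpy().tolist()))
  return counts
```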
We will select the model parameters corresponding to the MAX_TOKENS_SELECTED_PER_CLIENT most frequently occurring tokens on device. If fewer than this many tokens occur on device, we pad the list to enable the use of federated_select.
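Combining selection and padding, here is a sketch building on the Counter above (padding with token 0 is an arbitrary choice made here):

```python
def select_tokens(counts, max_tokens=MAX_TOKENS_SELECTED_PER_CLIENT):
  # Keep the most frequent tokens, then pad so the result always has a
  # statically known length of max_tokens, as federated_select requires.
  tokens = [token for token, _ in counts.most_common(max_tokens)]
  num_actual_tokens = len(tokens)
  tokens += [0] * (max_tokens - num_actual_tokens)
  return tokens, num_actual_tokens
```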
Note that other selection strategies may work better, for example randomly selecting tokens (perhaps weighted by their occurrence probability); this would ensure that all slices of the model for which the client has data have some chance of being updated.
Map global tokens to local tokens
The above selection gives us a dense set of tokens in the range [0, actual_num_tokens) which we will use for the on-device model. However, the dataset we read has tokens from the much larger global vocabulary range [0, WORD_VOCAB_SIZE).
Thus, we need to map the global tokens to their corresponding local tokens. The local token ids are simply given by the indexes into the selected_tokens tensor computed in the previous step.
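One way to realize this mapping is with a static lookup table; here is a sketch, assuming selected_tokens is a 1-D int64 tensor of the selected global token ids:

```python
import tensorflow as tf

def build_global_to_local(selected_tokens):
  # Local id = position of the global token in the selected_tokens tensor.
  local_ids = tf.range(tf.size(selected_tokens, out_type=tf.int64))
  return tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(selected_tokens, local_ids),
      default_value=-1)  # -1 marks global tokens that were not selected
```

Calling table.lookup on a batch of global token ids then yields local indexes, with -1 for tokens that were not selected.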
Train the local (sub)model on each client
Note that federated_select will return the selected slices as a tf.data.Dataset in the same order as the selection keys. So, we first define a utility function that takes such a Dataset and converts it to a single dense tensor which can be used as the model weights of the client model.
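An eager-mode sketch of that utility (the tutorial's real version must also work in graph mode inside a tff.tf_computation, so treat this as the intent rather than the implementation):

```python
import tensorflow as tf

def slices_dataset_to_tensor(slices_dataset):
  # Stack the selected slices, in selection-key order, into one dense
  # [num_slices, TAG_VOCAB_SIZE] matrix of client model weights.
  return tf.stack(list(slices_dataset))
```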
We now have all the components we need to define a simple local training loop which will run on each client.
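Here is a sketch of such a loop using a manual GradientTape, assuming initial_weights is the dense [M, TAG_VOCAB_SIZE] slice matrix and each batch carries float bag-of-words features x in the local token space and multi-hot labels y:

```python
import tensorflow as tf

def train_client_model(initial_weights, local_dataset,
                       learning_rate=0.1, epochs=3):
  weights = tf.Variable(initial_weights)
  optimizer = tf.keras.optimizers.SGD(learning_rate)
  for _ in range(epochs):
    for batch in local_dataset:
      with tf.GradientTape() as tape:
        logits = tf.matmul(batch['x'], weights)
        loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=batch['y'], logits=logits))
      grads = tape.gradient(loss, [weights])
      optimizer.apply_gradients(zip(grads, [weights]))
  return weights.read_value()
```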
Aggregate IndexedSlices
We use tff.federated_aggregate to construct a federated sparse sum for IndexedSlices. This simple implementation has the constraint that the dense_shape is known statically in advance. Note also that this sum is only semi-sparse, in the sense that the client -> server communication is sparse, but the server maintains a dense representation of the sum in accumulate and merge, and outputs this dense representation.
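Here is a sketch of the accumulate and merge steps under that design, assuming each client update arrives as an (indices, values) pair in the global token space and the accumulator is a dense [WORD_VOCAB_SIZE, TAG_VOCAB_SIZE] tensor:

```python
import tensorflow as tf

@tf.function
def accumulate(dense_sum, client_update):
  indices, values = client_update
  # Scatter-add the client's sparse rows into the dense running sum.
  return tf.tensor_scatter_nd_add(
      dense_sum, tf.expand_dims(indices, axis=-1), values)

@tf.function
def merge(sum_a, sum_b):
  # Both accumulators are already dense, so merging is plain addition.
  return sum_a + sum_b
```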
Construct a minimal federated_computation as a test
Putting it all together in a federated_computation
We now use TFF to bind together the components into a tff.federated_computation.
We use a basic server training function based on Federated Averaging, applying the update with a server learning rate of 1.0. It is important that we apply an update (delta) to the model rather than simply averaging client-supplied models: otherwise, if a given slice of the model wasn't trained by any client in a given round, its coefficients could be zeroed out.
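In code, the server step amounts to something like this sketch:

```python
SERVER_LEARNING_RATE = 1.0

def server_update(global_weights, summed_update, num_clients):
  # Average the aggregated delta and apply it to the current global model.
  mean_update = summed_update / num_clients
  return global_weights + SERVER_LEARNING_RATE * mean_update
```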
We need a couple more tff.tf_computation components:
We're now ready to put all the pieces together!
Let's train a model!
Now that we have our training function, let's try it out.
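A hypothetical driver loop, where initialize, train_one_round, and client_datasets stand in for the federated computations and data built above:

```python
server_state = initialize()
for round_num in range(10):
  server_state = train_one_round(server_state, client_datasets)
  print(f'Finished round {round_num}')
```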