Copyright 2020 The TensorFlow Authors.
Before we start
To edit the colab notebook, please go to "File" -> "Save a copy in Drive" and make any edits on your copy.
Before we start, please run the following to make sure that your environment is correctly set up. If you don't see a greeting, please refer to the Installation guide for instructions.
TensorFlow Federated for Image Classification
Let's experiment with federated learning in simulation. In this tutorial, we use the classic MNIST training example to introduce the Federated Learning (FL) API layer of TFF, tff.learning: a set of higher-level interfaces that can be used to perform common types of federated learning tasks, such as federated training, against user-supplied models implemented in TensorFlow.
Tutorial Outline
We'll be training a model to perform image classification using the federated version of EMNIST, an extension of the classic MNIST dataset, with the neural net learning to classify digits from images. In this case, we'll be simulating federated learning, with the training data distributed across different devices.
Sections
Load TFF Libraries.
Explore/preprocess federated EMNIST dataset.
Create a model.
Set up federated averaging process for training.
Analyze training metrics.
Set up federated evaluation computation.
Analyze evaluation metrics.
Preparing the input data
Let's start with the data. Federated learning requires a federated data set, i.e., a collection of data from multiple users. Federated data is typically non-i.i.d., which poses a unique set of challenges. Users typically have different distributions of data depending on usage patterns.
In order to facilitate experimentation, we seeded the TFF repository with a few datasets.
Here's how we can load our sample dataset.
The datasets returned by load_data() are instances of tff.simulation.datasets.ClientData, an interface that allows you to enumerate the set of users, to construct a tf.data.Dataset that represents the data of a particular user, and to query the structure of individual elements.
Let's explore the dataset.
Exploring non-iid data
Preprocessing the data
Since the data is already a tf.data.Dataset, preprocessing can be accomplished using Dataset transformations. See the tf.data documentation for more detail on these transformations.
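A framework-free sketch of this kind of per-client preprocessing, written with NumPy instead of tf.data, looks roughly like the following. It assumes each client's examples arrive as 28-by-28 pixel arrays with integer labels; NUM_EPOCHS, BATCH_SIZE, and the function name preprocess are illustrative choices, not values from the notebook.

```python
import numpy as np

# Illustrative hyperparameters (assumptions, not the notebook's values).
NUM_EPOCHS = 2
BATCH_SIZE = 20

def preprocess(pixels, labels, seed=0):
    """Flatten 28x28 images and return shuffled mini-batches.

    pixels: float array of shape [N, 28, 28]; labels: int array of shape [N].
    """
    rng = np.random.default_rng(seed)
    x = pixels.reshape(-1, 784).astype(np.float32)  # flatten each image to [784]
    y = labels.reshape(-1, 1).astype(np.int32)      # labels as a [N, 1] column
    # Repeat the client's data once per local epoch.
    x = np.tile(x, (NUM_EPOCHS, 1))
    y = np.tile(y, (NUM_EPOCHS, 1))
    # Full shuffle (tf.data's shuffle uses a finite buffer instead).
    order = rng.permutation(len(x))
    x, y = x[order], y[order]
    # Split into batches of at most BATCH_SIZE examples.
    return [{"x": x[i:i + BATCH_SIZE], "y": y[i:i + BATCH_SIZE]}
            for i in range(0, len(x), BATCH_SIZE)]
```

A tf.data version would typically chain repeat, shuffle, batch, and map transformations, which stream the data lazily instead of materializing every batch in memory as this sketch does.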
Let's verify this worked.
Here's a simple helper function that will construct a list of datasets from the given set of users as an input to a round of training or evaluation.
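A minimal sketch of such a helper, assuming the per-client data is held in a plain dict keyed by client ID. With a real tff.simulation.datasets.ClientData object, the lookup would instead go through its create_tf_dataset_for_client method; the name make_federated_data and the preprocess argument are illustrative.

```python
from typing import Callable, Dict, List, Sequence

def make_federated_data(client_data: Dict[str, list],
                        client_ids: Sequence[str],
                        preprocess: Callable[[list], list]) -> List[list]:
    """Return one preprocessed dataset per requested client, in the given order."""
    # Each list entry is the local dataset of one simulated client.
    return [preprocess(client_data[cid]) for cid in client_ids]
```

Keeping the output as a plain list, one entry per client, is what lets a single round of training or evaluation consume exactly the clients selected for that round.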
Now, how do we choose clients?
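One common answer in simulation is to draw a fresh uniform random sample of client IDs for every round, so that the set of participating users changes over time. The sketch below assumes the IDs come from a list such as ClientData.client_ids; sample_clients and its seeding scheme are hypothetical.

```python
import random
from typing import List, Sequence

def sample_clients(client_ids: Sequence[str], num_clients: int,
                   round_num: int, base_seed: int = 0) -> List[str]:
    """Pick num_clients distinct client IDs uniformly at random for one round.

    Seeding with the round number makes each round's sample reproducible.
    """
    rng = random.Random(base_seed + round_num)
    return rng.sample(list(client_ids), num_clients)
```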
Creating a model with Keras
If you are using Keras, you likely already have code that constructs a Keras model. Here's an example of a simple model that will suffice for our needs.
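As an illustration of what such a model computes, the NumPy sketch below implements a single dense layer from 784 flattened pixels to 10 classes followed by softmax, matching the 784-by-10 weight matrix and 10 biases that the model's type signature reports later in this tutorial. The zero initialization and the function names are assumptions; the notebook builds the model with Keras layers instead.

```python
import numpy as np

def init_weights(num_features=784, num_classes=10):
    """Zero-initialized kernel and bias (an assumption; any initializer works)."""
    return [np.zeros((num_features, num_classes), np.float32),
            np.zeros((num_classes,), np.float32)]

def softmax(logits):
    # Subtract the row max for numerical stability before exponentiating.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def forward(weights, x):
    """Map a batch x of shape [batch, 784] to class probabilities [batch, 10]."""
    kernel, bias = weights
    return softmax(x @ kernel + bias)
```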
Centralized training with Keras
Federated training using a Keras model
In order to use any model with TFF, it needs to be wrapped in an instance of the tff.learning.Model interface.
More Keras metrics that you can add are listed in the tf.keras.metrics documentation.
Training the model on federated data
Now that we have a model wrapped as tff.learning.Model for use with TFF, we can let TFF construct a Federated Averaging algorithm by invoking the helper function tff.learning.build_federated_averaging_process, as follows.
What just happened? TFF has constructed a pair of federated computations and packaged them into a tff.templates.IterativeProcess in which these computations are available as a pair of properties, initialize and next.
An iterative process will usually be driven by a control loop like:
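The loop below is a minimal sketch of that pattern. IterativeProcessSketch is a hypothetical stand-in for tff.templates.IterativeProcess that just bundles the two functions; for simplicity its next returns only the new state, whereas the Federated Averaging process built above returns the state together with training metrics.

```python
from typing import Any, Callable, NamedTuple

class IterativeProcessSketch(NamedTuple):
    """Hypothetical stand-in: an iterative process is just a pair of functions."""
    initialize: Callable[[], Any]
    next: Callable[[Any, Any], Any]

def run(process: IterativeProcessSketch, num_rounds: int,
        data_for_round: Callable[[int], Any]) -> Any:
    # Build the initial server state once, then apply one round at a time.
    state = process.initialize()
    for round_num in range(num_rounds):
        state = process.next(state, data_for_round(round_num))
    return state
```

Swapping in a different pair of functions changes the algorithm while the driver loop stays the same, which is the point of the initialize/next split.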
Let's invoke the initialize computation to construct the server state.
The second of the pair of federated computations, next, represents a single round of Federated Averaging, which consists of pushing the server state (including the model parameters) to the clients, on-device training on their local data, collecting and averaging model updates, and producing a new updated model at the server.
Let's run a single round of training and visualize the results. We can use the federated data we've already generated above for a sample of users.
Let's run a few more rounds. As noted earlier, typically at this point you would pick a subset of your simulation data from a new randomly selected sample of users for each round in order to simulate a realistic deployment in which users continuously come and go, but in this interactive notebook, for the sake of demonstration we'll just reuse the same users, so that the system converges quickly.
Training loss decreases after each round of federated training, indicating that the model is converging. However, there are some important caveats with these training metrics; see the Evaluation section later in this tutorial.
Displaying model metrics in TensorBoard
Next, let's visualize the metrics from these federated computations using TensorBoard.
Let's start by creating the directory and the corresponding summary writer to write the metrics to.
Plot the relevant scalar metrics with the same summary writer.
Start TensorBoard with the root log directory specified above. It can take a few seconds for the data to load.
In order to view evaluation metrics the same way, you can create a separate eval folder, like "logs/scalars/eval", to write to TensorBoard.
Evaluation
To perform evaluation on federated data, you can construct another federated computation designed for just this purpose, using the tff.learning.build_federated_evaluation function and passing in your model constructor as an argument.
Now, let's compile a test sample of federated data and rerun evaluation on the test data. The data will come from a different sample of users, taken from a distinct held-out dataset.
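How per-client evaluation results are combined matters when clients hold very different amounts of data. A minimal sketch, assuming each client reports a (num_correct, num_examples) pair, sums numerators and denominators separately so that each client counts in proportion to its example count. The function name is hypothetical and this is not taken from TFF's implementation.

```python
from typing import List, Tuple

def federated_accuracy(client_results: List[Tuple[int, int]]) -> float:
    """Combine per-client (num_correct, num_examples) pairs into one accuracy.

    Summing numerators and denominators separately weights each client by its
    number of examples, rather than averaging the per-client accuracies.
    """
    total_correct = sum(correct for correct, _ in client_results)
    total_examples = sum(n for _, n in client_results)
    return total_correct / total_examples
```

For example, one client with 9 of 10 correct and another with 10 of 100 correct give 19/110 (about 0.173) overall, whereas averaging the two per-client accuracies would give 0.5.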
This concludes the tutorial. We encourage you to play with the parameters (e.g., batch sizes, number of users, epochs, learning rates, etc.), to modify the code above to simulate training on random samples of users in each round, and to explore the other tutorials we've developed.
Build your own FL algorithms
In the previous tutorials, we learned how to set up model and data pipelines, and use these to perform federated training using the tff.learning API.
Of course, this is only the tip of the iceberg when it comes to FL research. In this tutorial, we are going to discuss how to implement federated learning algorithms without deferring to the tff.learning API. We aim to accomplish the following:
Goals:
Understand the general structure of federated learning algorithms.
Explore the Federated Core of TFF.
Use the Federated Core to implement Federated Averaging directly.
Preparing the input data
We first load and preprocess the EMNIST dataset included in TFF. We essentially use the same code as in the first tutorial.
Preparing the model
We use the same model as in the first tutorial: a single dense layer followed by a softmax layer.
We wrap this Keras model as a tff.learning.Model.
Customizing FL Algorithms
While the tff.learning API encompasses many variants of Federated Averaging, there are many other algorithms that do not fit neatly into this framework. For example, you may want to add regularization or clipping, or implement more complicated algorithms such as federated GAN training. You may instead be interested in federated analytics.
For these more advanced algorithms, we'll have to write our own custom FL algorithm.
In general, FL algorithms have 4 main components:
A server-to-client broadcast step.
A local client update step.
A client-to-server upload step.
A server update step.
In TFF, we generally represent federated algorithms as an IterativeProcess. This is simply a class that contains an initialize_fn and a next_fn. The initialize_fn will be used to initialize the server, and the next_fn will perform one communication round of federated averaging. Let's write a skeleton of what our iterative process for FedAvg should look like.
First, we have an initialize function that simply creates a tff.learning.Model and returns its trainable weights.
This function looks good, but as we will see later, we will need to make a small modification to make it a TFF computation.
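A minimal sketch of this skeleton, using a hypothetical dict-based stand-in for tff.learning.Model so that it runs without TFF; the two trainable tensors mirror the 784-by-10 kernel and 10-element bias of the model used here.

```python
import numpy as np

def model_fn():
    """Hypothetical stand-in for a tff.learning.Model: named weight lists only."""
    return {"trainable": [np.zeros((784, 10), np.float32),
                          np.zeros((10,), np.float32)],
            "non_trainable": []}

def initialize_fn():
    """Create a fresh model and return only its trainable weights."""
    model = model_fn()
    return model["trainable"]
```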
We also want to sketch the next_fn.
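Putting the four components listed earlier into that skeleton, a plain-Python sketch of one round might look like the following. It assumes weights are lists of NumPy arrays and that client_update and server_update are supplied by the caller; broadcasting is simulated by copying, and aggregation is a plain unweighted mean (an assumption; weighting by example count is also common).

```python
import copy
from typing import Callable, List, Sequence

import numpy as np

Weights = List[np.ndarray]

def next_fn(server_weights: Weights,
            federated_dataset: Sequence[object],
            client_update: Callable[[object, Weights], Weights],
            server_update: Callable[[Weights], Weights]) -> Weights:
    # 1. Server-to-client broadcast: each client gets its own copy of the weights.
    weights_at_clients = [copy.deepcopy(server_weights) for _ in federated_dataset]
    # 2. Local client update on each client's own data.
    client_weights = [client_update(data, w)
                      for data, w in zip(federated_dataset, weights_at_clients)]
    # 3. Client-to-server upload + aggregation: average each tensor across clients.
    mean_client_weights = [np.mean(np.stack(tensors), axis=0)
                           for tensors in zip(*client_weights)]
    # 4. Server update.
    return server_update(mean_client_weights)
```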
We'll focus on implementing these four components separately. We'll first focus on the parts that can be implemented in pure TensorFlow, namely the client and server update steps.
TensorFlow Blocks
Client update
We will use our tff.learning.Model to do client training in essentially the same way you would train a TF model. In particular, we will use tf.GradientTape to compute the gradient on batches of data, then apply these gradients using a client_optimizer.
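The notebook performs this step with tf.GradientTape and a Keras optimizer. As a framework-free sketch of the same idea, the NumPy code below runs plain SGD for a softmax-regression model, replacing automatic differentiation with the closed-form gradient of the mean cross-entropy loss (softmax probabilities minus the one-hot labels). The batch format and the learning_rate default are assumptions.

```python
import numpy as np

def client_update(dataset, server_weights, learning_rate=0.1):
    """Run local SGD on a softmax-regression model, starting from the server weights.

    dataset: iterable of batches {"x": float array [B, 784], "y": int array [B, 1]}.
    server_weights: [kernel (784, 10), bias (10,)]; these arrays are not modified.
    learning_rate: illustrative stand-in for the client optimizer's step size.
    """
    # Start from copies of the broadcast server weights.
    kernel, bias = (w.copy() for w in server_weights)
    for batch in dataset:
        x, y = batch["x"], batch["y"].reshape(-1)
        # Forward pass: numerically stable softmax over the logits.
        logits = x @ kernel + bias
        logits -= logits.max(axis=1, keepdims=True)
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        # Gradient of the mean cross-entropy w.r.t. the logits: (probs - one_hot(y)) / B.
        grad_logits = probs
        grad_logits[np.arange(len(y)), y] -= 1.0
        grad_logits /= len(y)
        # SGD step on both trainable tensors.
        kernel -= learning_rate * (x.T @ grad_logits)
        bias -= learning_rate * grad_logits.sum(axis=0)
    return [kernel, bias]
```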
Note that each tff.learning.Model instance has a weights attribute with two sub-attributes:
trainable: A list of the tensors corresponding to trainable layers.
non_trainable: A list of the tensors corresponding to non-trainable layers.
For our purposes, we will only use the trainable weights (as our model only has those!).
Server Update
The server update will require even less effort. We will implement vanilla federated averaging, in which we simply replace the server model weights by the average of the client model weights. Again, we will only focus on the trainable weights.
Note that the code snippet above is clearly overkill, as we could simply return mean_client_weights. However, more advanced implementations of Federated Averaging could combine mean_client_weights with more sophisticated techniques, such as momentum or adaptivity.
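A minimal sketch of both options, assuming weights are lists of NumPy arrays. server_update is the vanilla rule described above. server_update_with_momentum is a hypothetical illustration of the momentum idea: it treats the difference between the current server weights and the client average as a pseudo-gradient and applies heavy-ball momentum to it; the parameter names and defaults are illustrative.

```python
import numpy as np

def server_update(server_weights, mean_client_weights):
    """Vanilla FedAvg: the new server model is simply the client average."""
    return [w.copy() for w in mean_client_weights]

def server_update_with_momentum(server_weights, mean_client_weights, velocity,
                                server_lr=1.0, beta=0.9):
    """Hypothetical momentum variant of the server step.

    Returns (new_weights, new_velocity); velocity has the same structure as the weights.
    """
    new_weights, new_velocity = [], []
    for w, m, v in zip(server_weights, mean_client_weights, velocity):
        pseudo_grad = w - m          # direction the clients moved away from
        v = beta * v + pseudo_grad   # heavy-ball accumulation
        new_velocity.append(v)
        new_weights.append(w - server_lr * v)
    return new_weights, new_velocity
```

With beta=0 and server_lr=1 the momentum variant reduces exactly to the vanilla rule, which is a convenient sanity check.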
So far, we've only written pure TensorFlow code. This is by design, as TFF allows you to use much of the TensorFlow code you're already familiar with. However, now we have to specify the orchestration logic, that is, the logic that dictates what the server broadcasts to the client, and what the client uploads to the server.
This will require the "Federated Core" of TFF.
Introduction to the Federated Core
The Federated Core (FC) is a set of lower-level interfaces that serve as the foundation for the tff.learning API. However, these interfaces are not limited to learning. In fact, they can be used for analytics and many other computations over distributed data.
At a high level, the Federated Core is a development environment that enables compactly expressed program logic to combine TensorFlow code with distributed communication operators (such as distributed sums and broadcasts). The goal is to give researchers and practitioners explicit control over the distributed communication in their systems, without requiring them to handle system implementation details (such as specifying point-to-point network message exchanges).
One key point is that TFF is designed for privacy-preservation. Therefore, it allows explicit control over where data resides, to prevent unwanted accumulation of data at the centralized server location.
Federated data
Just as the tensor is one of the fundamental concepts in TensorFlow, a key concept in TFF is "federated data", which refers to a collection of data items hosted across a group of devices in a distributed system (e.g. client datasets, or the server model weights). We model the entire collection of data items across all devices as a single federated value.
For example, suppose we have client devices that each have a float representing the temperature of a sensor. We could represent all of these readings together as a single federated float.
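A federated type is specified by the type T of its member constituents (e.g. tf.float32) and a group G of devices. We will focus on the cases where G is either tff.CLIENTS or tff.SERVER. Such a federated type is represented as {T}@G; the braces indicate that the members may differ from device to device, while a single value held by the server is written without them, as in float32@SERVER.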
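To make the notation concrete, here is a hypothetical Python class that only mimics how such types are printed; it is not part of TFF. In TFF itself, the client-placed float type for the temperature example would be created with tff.type_at_clients(tf.float32).

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FederatedTypeSketch:
    """Hypothetical model of a federated type: member type T plus placement G."""
    member: str              # member constituent type T, e.g. "float32"
    placement: str           # device group G, e.g. "CLIENTS" or "SERVER"
    all_equal: bool = False  # True when every device in G holds the same value

    def __str__(self) -> str:
        # Braces mark a collection of possibly different per-device values;
        # a value known to be identical everywhere is printed without them.
        member = self.member if self.all_equal else "{" + self.member + "}"
        return f"{member}@{self.placement}"

# Temperatures on client sensors: one (possibly different) float per client.
client_temperatures = FederatedTypeSketch("float32", "CLIENTS")
# A single float held by the server.
server_average = FederatedTypeSketch("float32", "SERVER", all_equal=True)
```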
Why do we care so much about placements? A key goal of TFF is to enable writing code that could be deployed on a real distributed system. This means that it is vital to reason about which subsets of devices execute which code, and where different pieces of data reside.
TFF focuses on three things: data, where the data is placed, and how the data is being transformed. The first two are encapsulated in federated types, while the last is encapsulated in federated computations.
Federated computations
TFF is a strongly-typed functional programming environment whose basic units are federated computations. These are pieces of logic that accept federated values as input, and return federated values as output.
For example, suppose we wanted to average the temperatures on our client sensors. We could define the following (using our federated float):
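The semantics of such an averaging computation can be sketched in plain Python: a {float32}@CLIENTS value is modeled as a list with one reading per client, and the result is a single float that would live at the server. federated_mean_sketch and get_average_temperature are illustrative names; in TFF the body of the computation would be a single call to tff.federated_mean.

```python
from typing import List

def federated_mean_sketch(client_values: List[float]) -> float:
    """Hypothetical stand-in for tff.federated_mean.

    Input: a {float32}@CLIENTS value, modeled as one float per client.
    Output: a single float, which in TFF would be placed at the server.
    """
    if not client_values:
        raise ValueError("at least one client value is required")
    return sum(client_values) / len(client_values)

def get_average_temperature(sensor_readings: List[float]) -> float:
    # The whole computation is one federated operator applied to the input.
    return federated_mean_sketch(sensor_readings)
```

Calling get_average_temperature([68.5, 70.3, 69.8]) returns roughly 69.53. Unlike this sketch, a real TFF computation is traced once and carries a fixed type signature.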
You might ask: how is this different from the tf.function decorator in TensorFlow? The key answer is that the code generated by tff.federated_computation is neither TensorFlow nor Python code; it is a specification of a distributed system in an internal platform-independent glue language.
While this may sound complicated, you can think of TFF computations as functions with well-defined type signatures. These type signatures can be directly queried.
This tff.federated_computation accepts arguments of federated type {float32}@CLIENTS and returns values of federated type float32@SERVER. Federated computations may also go from server to client, from client to client, or from server to server. Federated computations can also be composed like normal functions, as long as their type signatures match up.
To support development, TFF allows you to invoke a tff.federated_computation as a Python function. For example, we can call the temperature-averaging computation directly on a Python list of floats, one per client.
Non-eager computations and TensorFlow
There are two key restrictions to be aware of. First, when the Python interpreter encounters a tff.federated_computation decorator, the function is traced once and serialized for future use. Therefore, TFF computations are fundamentally non-eager. This behavior is somewhat analogous to that of the tf.function decorator in TensorFlow.
Second, a federated computation can only consist of federated operators (such as tff.federated_mean); it cannot contain TensorFlow operations. TensorFlow code must be confined to blocks decorated with tff.tf_computation. Most ordinary TensorFlow code can be directly decorated, such as a function add_half that takes a number and adds 0.5 to it.
These also have type signatures, but without placements. For example, add_half has the type signature (float32 -> float32).
Here we see an important difference between tff.federated_computation and tff.tf_computation: the former has explicit placements, while the latter does not.
We can use tff.tf_computation blocks in federated computations by specifying placements. Let's create a function that adds half, but only to federated floats at the clients. We can do this by using tff.federated_map, which applies a given tff.tf_computation to each member of a federated value while preserving the placement.
This function is almost identical to add_half, except that it only accepts values with placement at tff.CLIENTS and returns values with the same placement. We can see this in its type signature, ({float32}@CLIENTS -> {float32}@CLIENTS).
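The placement-preserving behavior of tff.federated_map can be mimicked in plain Python. In the hypothetical sketch below, a federated value is a (placement, members) pair; add_half plays the role of the placement-free tff.tf_computation, and add_half_on_clients rejects anything not placed at the clients, loosely imitating the type check TFF performs against the real computation's signature.

```python
from typing import Callable, List, Tuple

Placed = Tuple[str, List[float]]  # (placement, member values)

def add_half(x: float) -> float:
    # Placement-free local computation, analogous to a tff.tf_computation.
    return x + 0.5

def federated_map_sketch(fn: Callable[[float], float], value: Placed) -> Placed:
    """Apply fn to every member, keeping the placement unchanged."""
    placement, members = value
    return placement, [fn(m) for m in members]

def add_half_on_clients(value: Placed) -> Placed:
    placement, _ = value
    if placement != "CLIENTS":
        raise TypeError(f"expected a value placed at CLIENTS, got {placement}")
    return federated_map_sketch(add_half, value)
```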
In summary:
TFF operates on federated values.
Each federated value has a federated type, with a type (e.g. tf.float32) and a placement (e.g. tff.CLIENTS).
Federated values can be transformed using federated computations, which must be decorated with tff.federated_computation and a federated type signature.
TensorFlow code must be contained in blocks with tff.tf_computation decorators.
These blocks can then be incorporated into federated computations.
Building your own FL Algorithm (Part 2)
Now that we've peeked at the Federated Core, we can build our own federated learning algorithm. Remember that above, we defined an initialize_fn and next_fn for our algorithm. The next_fn will make use of the client_update and server_update we defined using pure TensorFlow code.
However, in order to make our algorithm a federated computation, we will need both the next_fn and initialize_fn to be tff.federated_computation instances.
TensorFlow Federated blocks
Creating the initialization computation
The initialize function will be quite simple: we will create a model using model_fn. However, remember that we must separate out our TensorFlow code using tff.tf_computation.
We can then pass this directly into a federated computation using tff.federated_value.
Creating the next_fn
We now use our client and server update code to write the actual algorithm. We will first turn our client_update into a tff.tf_computation that accepts a client dataset and the server weights, and outputs the updated client weights.
We will need the corresponding types to properly decorate our function. Luckily, the type of the server weights can be extracted directly from our model.
Let's look at the dataset type signature. Remember that we took 28 by 28 images (with integer labels) and flattened them.
We can also extract the model weights type by using our server_init function above.
Examining the type signature, we'll be able to see the architecture of our model!
We can now create our tff.tf_computation for the client update.
The tff.tf_computation version of the server update can be defined in a similar way, using types we've already extracted.
Last, but not least, we need to create the tff.federated_computation that brings this all together. This function will accept two federated values, one corresponding to the server weights (with placement tff.SERVER), and the other corresponding to the client datasets (with placement tff.CLIENTS).
Note that both these types were defined above! We simply need to give them the proper placement using tff.type_at_server and tff.type_at_clients.
Remember the 4 elements of an FL algorithm?
A server-to-client broadcast step.
A local client update step.
A client-to-server upload step.
A server update step.
Now that we've built up the above, each part can be compactly represented as a single line of TFF code. This simplicity is why we had to take extra care to specify things such as federated types!
We now have a tff.federated_computation for both the algorithm initialization and for running one step of the algorithm. To finish our algorithm, we pass these into tff.templates.IterativeProcess.
Let's look at the type signature of the initialize and next functions of our iterative process.
This reflects the fact that federated_algorithm.initialize is a no-arg function that returns a single-layer model (with a 784-by-10 weight matrix and 10 bias units).
Here, we see that federated_algorithm.next accepts a server model and client data, and returns an updated server model.
Evaluating the algorithm
Let's run a few rounds, and see how the loss changes. First, we will define an evaluation function using the centralized approach discussed in the second tutorial.
We first create a centralized evaluation dataset, and then apply the same preprocessing we used for the training data.
Note that we only take the first 1000 elements for reasons of computational efficiency, but typically we'd use the entire test dataset.
Next, we write a function that accepts a server state and uses Keras to evaluate on the test dataset. If you're familiar with tf.keras, this will all look familiar, though note the use of set_weights!
Now, let's initialize our algorithm and evaluate on the test set.
Let's train for a few rounds and see if anything changes.
We see a slight decrease in the loss function. While the jump is small, note that we've only performed 10 training rounds, and on a small subset of clients. To see better results, we may have to do hundreds if not thousands of rounds.
Modifying our algorithm
At this point, let's stop and think about what we've accomplished. We've implemented Federated Averaging directly by combining pure TensorFlow code (for the client and server updates) with federated computations from the Federated Core of TFF.
To perform more sophisticated learning, we can simply alter what we have above. In particular, by editing the pure TF code above, we can change how the client performs training, or how the server updates its model.
Challenge: Add gradient clipping to the client_update function.
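One possible approach, sketched here rather than offered as the intended solution, is to clip gradients by their global norm before each optimizer step. In TensorFlow this is tf.clip_by_global_norm(grads, clip_norm), which returns the clipped list together with the original global norm; the NumPy function below reproduces that behavior so the arithmetic is easy to check. clip_norm is a hyperparameter you would need to choose.

```python
import numpy as np

def clip_by_global_norm(grads, clip_norm):
    """NumPy analogue of tf.clip_by_global_norm.

    Rescales all gradient tensors by the same factor so that their combined
    L2 norm is at most clip_norm. Returns (clipped_grads, global_norm).
    """
    global_norm = float(np.sqrt(sum(np.sum(np.square(g)) for g in grads)))
    # Equivalent to clip_norm / max(global_norm, clip_norm), guarding against zero.
    scale = min(1.0, clip_norm / global_norm) if global_norm > 0 else 1.0
    return [g * scale for g in grads], global_norm
```

Inside client_update, the call would go between computing grads with the tape and passing them to client_optimizer.apply_gradients.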