Imports and creating workers
We use all the basic imports that we normally require when working on any deep learning problem with PyTorch.
The extra piece we need is PySyft, which we hook onto PyTorch to add all the federated-learning functionality we discussed in the introduction to the API section.
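A minimal sketch of what this cell might look like; the worker ids school1 and school2 are placeholders standing in for the two schools:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import syft as sy

# Hook PySyft onto PyTorch so tensors (and models) gain .send(), .get() and friends
hook = sy.TorchHook(torch)

# Two virtual workers standing in for the two schools (ids are illustrative)
school1 = sy.VirtualWorker(hook, id="school1")
school2 = sy.VirtualWorker(hook, id="school2")
```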
Args
Now we define the hyper-parameters such as the learning rate, batch size, test batch size, and number of epochs.
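For reference, a sketch of such a hyper-parameter dict; the exact values here are illustrative, but later cells index into it as args['epochs'] and so on:

```python
# Hyper-parameters collected in a plain dict (values are illustrative)
args = {
    'batch_size': 64,
    'test_batch_size': 1000,
    'epochs': 10,
    'lr': 0.01,
    'log_interval': 100,
}

# Use CUDA if available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```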
CNN Model
Now we define a very simple CNN.
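A minimal CNN of the kind used here might look like the following; the layer sizes are illustrative, assuming 28x28 single-channel inputs:

```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # two conv + max-pool stages, then two fully connected layers
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)
```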
Sending Data to schools
We load the data first and then transform it into a federated dataset using the .federate()
method. It does a couple of things for us:
- It splits the dataset into two parts (which the standard torch DataLoader would also do)
- The extra thing it does is send these two parts across to two remote workers, in our case the two schools.
We will then use this newly created federated dataset to iterate over remote batches during our training loop.
Below, I extract one (images, labels) batch to show that these are pointers.
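A sketch of how that typically looks, assuming the MNIST dataset, the usual normalisation constants, and the two school workers created above:

```python
# Load MNIST and federate it across the two school workers
federated_train_loader = sy.FederatedDataLoader(
    datasets.MNIST(
        '../data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ).federate((school1, school2)),
    batch_size=args['batch_size'], shuffle=True,
)

# A plain (local) loader for the test set, since testing stays on our machine
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        '../data', train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=args['test_batch_size'], shuffle=True,
)

# Peek at one batch: data and target are PointerTensors living on a school worker
data, target = next(iter(federated_train_loader))
print(data.location, target.location)
```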
Train and Val
Now, each time we train the model, we need to send it to the right location for each batch. We use the .send()
function that we learnt about above to do this.
Then we perform all the operations remotely with the same syntax we would use in local PyTorch. When we're done, we get back the updated model using the .get()
method.
Note in the train function below that (data, target)
is a pair of PointerTensors. From a PointerTensor we can get the worker it points to using the .location
attribute, and that is precisely what we use to send the model to the correct location.
The test function remains unchanged, since it runs locally on our machine while training happens remotely.
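A sketch of what the two functions can look like under those assumptions; the logging format and the log_interval key are illustrative:

```python
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()

    # iterate over federated data: (data, target) are PointerTensors
    for batch_idx, (data, target) in enumerate(train_loader):

        # send the model to the worker that holds this batch
        model.send(data.location)

        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # forward/backward run remotely with the usual PyTorch syntax
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        # get the updated model back from the remote worker
        model.get()

        if batch_idx % args['log_interval'] == 0:
            # the loss is still a pointer, so fetch it before printing
            loss = loss.get()
            print(f"Train Epoch: {epoch} "
                  f"[{batch_idx * args['batch_size']}/{len(train_loader) * args['batch_size']} "
                  f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")


def test(model, device, test_loader):
    # standard local evaluation loop, nothing federated here
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print(f"Test set: Average loss: {test_loss:.4f}, "
          f"Accuracy: {correct}/{len(test_loader.dataset)} "
          f"({100. * correct / len(test_loader.dataset):.0f}%)")
```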
Start the training
We can finally start training the model, and the best part is that we use the same code as when training locally. Using exactly the code explained in this notebook, I was able to get an accuracy of 98%, which is quite good.
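The driver loop, as it also appears in the interrupted traceback below, together with an illustrative model and optimizer setup:

```python
model = Net().to(device)
# plain SGD; the learning rate comes from the args dict defined earlier
optimizer = optim.SGD(model.parameters(), lr=args['lr'])

for epoch in range(1, args['epochs'] + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(model, device, test_loader)
```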
(The original cell output here is a long KeyboardInterrupt traceback from stopping the training cell manually; the interrupt lands inside PySyft's FederatedDataLoader while it serializes and sends a batch command to one of the remote school workers.)