GitHub Repository: hackassin/learnopencv
Path: blob/master/Federated-Learning-Intro/Federated_Learning_Tutorial.ipynb
Kernel: Python 3

Imports and creating workers

We start with the basic imports that any deep learning problem with PyTorch requires.

The extra piece we need is PySyft, which we hook onto PyTorch to add the functionality required for federated learning, as discussed in the introduction to the API section.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import logging

# import PySyft to help us simulate federated learning
import syft as sy

# hook PyTorch to PySyft, i.e. add extra functionalities to support
# Federated Learning and other private AI tools
hook = sy.TorchHook(torch)

# we create two imaginary schools
westside_school = sy.VirtualWorker(hook, id="westside")
grapevine_high = sy.VirtualWorker(hook, id="grapevine")
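To see what the hook gives us, here is a minimal sketch (not part of the original notebook) that sends a plain tensor to one of the schools and fetches it back; .send(), .location, and .get() are the same primitives the training loop will rely on later:

# a quick sketch of the pointer workflow, using the workers created above
x = torch.tensor([1, 2, 3]).send(westside_school)
print(x)            # a wrapper around a PointerTensor living on 'westside'
print(x.location)   # the VirtualWorker the pointer refers to
x = x.get()         # retrieve the actual tensor back from the school
print(x)            # tensor([1, 2, 3])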

Args

Now we define hyper-parameters such as the learning rate, batch size, and test batch size.

# define the args
args = {
    'use_cuda': True,
    'batch_size': 64,
    'test_batch_size': 1000,
    'lr': 0.01,
    'log_interval': 100,
    'epochs': 10
}

# check whether to use the GPU or not
use_cuda = args['use_cuda'] and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

CNN Model

Now we define a very simple CNN.

# create a simple CNN net
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(in_features=64*12*12, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),
        )

    def forward(self, x):
        x = self.conv(x)
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 64*12*12)
        x = self.fc(x)
        x = F.log_softmax(x, dim=1)
        return x
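The 64*12*12 flatten size follows from the input dimensions: each 3x3 convolution with stride 1 and no padding shrinks a 28x28 MNIST image by 2 pixels per side (28 to 26, then 24), and the 2x2 max-pool halves that to 12x12 across 64 channels. A quick sanity check (not part of the original notebook) confirms the output shape:

# pass one dummy MNIST-sized image through the network to verify the shapes
net = Net()
dummy = torch.zeros(1, 1, 28, 28)   # a batch of one fake 28x28 grayscale image
print(net(dummy).shape)             # torch.Size([1, 10]) -- one log-probability per digit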

Sending Data to schools

We first load the data and then transform it into a federated dataset using the .federate() method, which does a couple of things for us:

  • It splits the dataset into two parts (just as the torch DataLoader splits data into parts).
  • The extra thing it does is send those parts to the two remote workers, in our case the two schools.

We will then use this newly created federated dataset to iterate over remote batches during our training loop; the sketch below shows what .federate() produces on a toy dataset.
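To make the mechanics concrete, here is a minimal sketch (assuming PySyft's BaseDataset helper, with a toy dataset standing in for MNIST) of what .federate() produces:

# a toy dataset of four examples, federated across the two schools
toy = sy.BaseDataset(torch.tensor([1., 2., 3., 4.]),
                     torch.tensor([0, 1, 0, 1]))
fed = toy.federate((grapevine_high, westside_school))
print(fed.workers)   # the ids of the workers now holding the data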

# Now we take the help of PySyft's awesome API to prepare the data for us
# and distribute it across the 2 workers, i.e. the two schools.
# Normally we don't have to distribute data; the data is already at each site.
# We are doing this just to simulate federated learning.
# The code below looks just like torch code, with some minor changes.
# This is what's nice about PySyft.
federated_train_loader = sy.FederatedDataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
    .federate((grapevine_high, westside_school)),
    batch_size=args['batch_size'], shuffle=True)

# the test data remains with us locally
# this is the normal torch code to load test data from MNIST
# that we are all familiar with
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args['test_batch_size'], shuffle=True)

Below, we extract one (images, labels) batch to show that they are pointers. Notice in the output that both pointers in the batch point to the same worker (grapevine here): each batch lives entirely at one school.

# we can look at the data; the batches are actually pointer tensors
for images, labels in federated_train_loader:
    print(images)       # a batch of image pointers
    print(labels)       # a batch of label pointers
    print(len(images))  # the len function works on pointers as well
    print(len(labels))  # same count for images and labels
    break
(Wrapper)>[PointerTensor | me:46558879977 -> grapevine:64689167388]
(Wrapper)>[PointerTensor | me:54846295197 -> grapevine:37439754962]
64
64

Train and Val

Now, each time we train the model, we need to send it to the right location for each batch. We use the .send() function that we saw above to do this.

Then we perform all the operations remotely with the same syntax as local PyTorch. When we're done, we get the updated model back using the .get() method.
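Remote execution really is syntactically transparent. In this minimal sketch (not part of the original notebook, reusing the workers created earlier), operations on pointers run on the remote worker and return new pointers, which we can .get() when we want the result:

# arithmetic on pointers runs remotely and yields another pointer
a = torch.tensor([1., 2.]).send(grapevine_high)
b = torch.tensor([3., 4.]).send(grapevine_high)
c = a + b          # the addition happens on 'grapevine'
print(c)           # a pointer to the result tensor
print(c.get())     # tensor([4., 6.]) -- the result, fetched back locally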

Note in the train function below that (data, target) is a pair of PointerTensors. A PointerTensor exposes the worker it points to through its .location attribute, and that is precisely what we use to send the model to the correct location.

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    # iterate over federated data
    for batch_idx, (data, target) in enumerate(train_loader):
        # send the model to the remote location
        model = model.send(data.location)

        # the same torch code that we are used to
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)

        # this loss is a pointer to the loss tensor at the remote location
        loss = F.nll_loss(output, target)

        # calling backward() on the loss pointer sends the command to call
        # backward on the actual loss tensor present on the remote machine
        loss.backward()
        optimizer.step()

        # get back the updated model
        model.get()

        if batch_idx % args['log_interval'] == 0:
            # the loss variable was also created at the remote worker,
            # so we need to explicitly get it back
            loss = loss.get()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * args['batch_size'],          # no. of images done
                len(train_loader) * args['batch_size'],  # total no. of images
                100. * batch_idx / len(train_loader),
                loss.item()))

The test function remains unchanged, since evaluation runs locally on our machine whereas training happens remotely.

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # add the batch losses together
            test_loss += F.nll_loss(output, target, reduction='sum').item()

            # get the index of the class with the max probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

Start the training

We can now finally start training the model, and the best part is that we use the same code as when training locally. Using the exact code explained in this notebook, I was able to reach an accuracy of 98%, which is quite good.

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args['lr'])

logging.info("Starting training !!")

for epoch in range(1, args['epochs'] + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(model, device, test_loader)

# that's all we need to do
Train Epoch: 1 [0/60032 (0%)]	Loss: 2.309102
Train Epoch: 1 [6400/60032 (11%)]	Loss: 0.470080
Train Epoch: 1 [12800/60032 (21%)]	Loss: 0.641319
Train Epoch: 1 [19200/60032 (32%)]	Loss: 0.283288
Train Epoch: 1 [25600/60032 (43%)]	Loss: 0.244961
Train Epoch: 1 [32000/60032 (53%)]	Loss: 0.203230
Train Epoch: 1 [38400/60032 (64%)]	Loss: 0.256210
Train Epoch: 1 [44800/60032 (75%)]	Loss: 0.113366
Train Epoch: 1 [51200/60032 (85%)]	Loss: 0.292694
Train Epoch: 1 [57600/60032 (96%)]	Loss: 0.234568

Test set: Average loss: 0.1980, Accuracy: 9412/10000 (94%)

Train Epoch: 2 [0/60032 (0%)]	Loss: 0.121836
Train Epoch: 2 [6400/60032 (11%)]	Loss: 0.173105
Train Epoch: 2 [12800/60032 (21%)]	Loss: 0.098377
Train Epoch: 2 [19200/60032 (32%)]	Loss: 0.241032
KeyboardInterrupt: training was stopped manually partway through epoch 2.
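Since train() calls .get() after every batch, the updated model ends up back on our machine, so it can be saved with plain PyTorch. A minimal sketch (the filename is just an example, not from the original notebook):

# save the locally held weights the usual PyTorch way
torch.save(model.state_dict(), "federated_mnist_cnn.pt")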