
GitHub Repository: huggingface/notebooks
Path: blob/main/examples/accelerate_examples/simple_cv_example.ipynb

Launching Multi-Node Training from a Jupyter Environment

Using notebook_launcher to run Accelerate from inside a Jupyter Notebook

General Overview

This notebook covers how to run the cv_example.py script as a Jupyter Notebook and train it on a distributed system. It also covers the specific requirements for making sure your environment is configured properly and your data is prepared properly, and finally how to launch training.

Configuring the Environment

Before any training can be performed, an accelerate config file must exist in the system. Usually this can be done by running the following in a terminal:

accelerate config

However, if general defaults are fine and you are not running on a TPU, accelerate has a utility to quickly write your GPU configuration into a config file via write_basic_config.

The following cell will restart Jupyter after writing the configuration, since CUDA code was called to do so. CUDA can't be initialized more than once in the same process (once for the single GPU the notebook uses by default, and then again when notebook_launcher is called). It's fine to debug in the notebook and make calls to CUDA, but remember that in order to finally train, a full cleanup and restart will need to be performed, such as what is shown below:

# import os
# from accelerate.utils import write_basic_config
# write_basic_config()  # Write a config file
# os._exit(00)  # Restart the notebook

Preparing the Dataset and Model

Next you should prepare your dataset. As mentioned earlier, great care should be taken when preparing the DataLoaders and model to make sure that nothing is put on any GPU.

If you do need code that places something on a GPU, it is recommended to put that specific code into a function and call it from within the notebook launcher interface, as sketched just below and shown in full later.
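For instance, here is a minimal sketch of that pattern (the function names build_model and launched_fn are illustrative, not part of this notebook): anything that would put a model or tensor on a GPU stays inside a function that only runs once notebook_launcher has spawned the training processes.

import torch
from timm import create_model

def build_model(num_classes: int):
    # Illustrative helper: the model is created on the CPU here, and only the
    # launched training function moves it to a device.
    return create_model("resnet50d", pretrained=True, num_classes=num_classes)

def launched_fn():
    # This function is what would be handed to notebook_launcher, so any CUDA
    # initialization happens inside the spawned processes, not in the notebook itself.
    model = build_model(num_classes=2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)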

Make sure the dataset is downloaded based on the directions here

import os, re, torch, PIL
import numpy as np
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor
from accelerate import Accelerator
from accelerate.utils import set_seed
from timm import create_model

First we'll create a function to extract the class name based on a file:

import os

data_dir = "../../images"
fnames = os.listdir(data_dir)
fname = fnames[0]
print(fname)
beagle_32.jpg

In the case here, the label is beagle:

import re

def extract_label(fname):
    stem = fname.split(os.path.sep)[-1]
    return re.search(r"^(.*)_\d+\.jpg$", stem).groups()[0]
extract_label(fname)
'beagle'

Next we'll create a Dataset class:

class PetsDataset(Dataset):
    def __init__(self, file_names, image_transform=None, label_to_id=None):
        self.file_names = file_names
        self.image_transform = image_transform
        self.label_to_id = label_to_id

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        fname = self.file_names[idx]
        raw_image = PIL.Image.open(fname)
        image = raw_image.convert("RGB")
        if self.image_transform is not None:
            image = self.image_transform(image)
        label = extract_label(fname)
        if self.label_to_id is not None:
            label = self.label_to_id[label]
        return {"image": image, "label": label}

And build our dataset

# Grab all the image filenames
fnames = [
    os.path.join(data_dir, fname)
    for fname in fnames
    if fname.endswith(".jpg")
]

# Build the labels
all_labels = [extract_label(fname) for fname in fnames]
id_to_label = list(set(all_labels))
id_to_label.sort()
label_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}

Note: This will be stored inside of a function as we'll be setting our seed during training.

def get_dataloaders(batch_size: int = 64):
    "Builds a set of dataloaders with a batch_size"
    random_perm = np.random.permutation(len(fnames))
    cut = int(0.8 * len(fnames))
    train_split = random_perm[:cut]
    eval_split = random_perm[cut:]

    # For training we use a simple RandomResizedCrop
    train_tfm = Compose([
        RandomResizedCrop((224, 224), scale=(0.5, 1.0)),
        ToTensor()
    ])
    train_dataset = PetsDataset(
        [fnames[i] for i in train_split],
        image_transform=train_tfm,
        label_to_id=label_to_id
    )

    # For evaluation we use a deterministic Resize
    eval_tfm = Compose([
        Resize((224, 224)),
        ToTensor()
    ])
    eval_dataset = PetsDataset(
        [fnames[i] for i in eval_split],
        image_transform=eval_tfm,
        label_to_id=label_to_id
    )

    # Instantiate dataloaders
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=4
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        shuffle=False,
        batch_size=batch_size * 2,
        num_workers=4
    )
    return train_dataloader, eval_dataloader
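As an optional sanity check (safe to run in the notebook itself, since it touches no CUDA code), you can pull a single batch and confirm the tensor shapes; the shapes noted in the comment are only what we would expect, not captured output:

# Optional sanity check: inspect one training batch on the CPU.
check_train_dl, _ = get_dataloaders(batch_size=64)
batch = next(iter(check_train_dl))
print(batch["image"].shape, batch["label"].shape)
# Expected shapes: torch.Size([64, 3, 224, 224]) and torch.Size([64])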

Writing the Training Function

Now we can build our training loop. notebook_launcher works by passing in a function that will be run across the distributed system.

Here is a basic training loop for our animal classification problem:

from torch.optim.lr_scheduler import OneCycleLR
def training_loop(mixed_precision="fp16", seed: int = 42, batch_size: int = 64):
    set_seed(seed)
    # Initialize accelerator
    accelerator = Accelerator(mixed_precision=mixed_precision)
    # Build dataloaders
    train_dataloader, eval_dataloader = get_dataloaders(batch_size)

    # Instantiate the model (we build the model here so that the seed also controls new weight initializations)
    model = create_model("resnet50d", pretrained=True, num_classes=len(label_to_id))

    # Freeze the base model
    for param in model.parameters():
        param.requires_grad = False
    for param in model.get_classifier().parameters():
        param.requires_grad = True

    # We normalize the batches of images to be a bit faster
    mean = torch.tensor(model.default_cfg["mean"])[None, :, None, None]
    std = torch.tensor(model.default_cfg["std"])[None, :, None, None]

    # To make these constants available on the active device, we move them to the accelerator device
    mean = mean.to(accelerator.device)
    std = std.to(accelerator.device)

    # Instantiate the optimizer
    optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-2 / 25)

    # Instantiate the learning rate scheduler
    lr_scheduler = OneCycleLR(
        optimizer=optimizer, max_lr=3e-2, epochs=5, steps_per_epoch=len(train_dataloader)
    )

    # Prepare everything
    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
    # prepare method.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # Now we train the model
    for epoch in range(5):
        model.train()
        for step, batch in enumerate(train_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch = {k: v.to(accelerator.device) for k, v in batch.items()}
            inputs = (batch["image"] - mean) / std
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, batch["label"])
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        model.eval()
        accurate = 0
        num_elems = 0
        for _, batch in enumerate(eval_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch = {k: v.to(accelerator.device) for k, v in batch.items()}
            inputs = (batch["image"] - mean) / std
            with torch.no_grad():
                outputs = model(inputs)
            predictions = outputs.argmax(dim=-1)
            accurate_preds = accelerator.gather(predictions) == accelerator.gather(batch["label"])
            num_elems += accurate_preds.shape[0]
            accurate += accurate_preds.long().sum()

        eval_metric = accurate.item() / num_elems
        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}: {100 * eval_metric:.2f}")

All that's left is to use the notebook_launcher.

We pass in the function, the arguments (as a tuple), and the number of processes to train on. (See the documentation for more information)

from accelerate import notebook_launcher
args = ("fp16", 42, 64)
notebook_launcher(training_loop, args, num_processes=2)
Launching training on 2 GPUs.
epoch 0: 88.12
epoch 1: 91.73
epoch 2: 92.58
epoch 3: 93.90
epoch 4: 94.71

And that's it!

Conclusion

This notebook showed how to perform distributed training from inside of a Jupyter Notebook. Some key notes to remember:

  • Make sure to save any code that uses CUDA (or CUDA imports) for the function passed to notebook_launcher

  • Set num_processes to the number of devices used for training (such as the number of GPUs, CPUs, or TPU cores), as in the sketch below
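For example, rather than hard-coding the process count, you could derive it from the visible hardware (a small sketch, assuming a CUDA machine; adjust for TPU or CPU-only setups):

import torch
from accelerate import notebook_launcher

# Launch one process per visible GPU, falling back to a single process on CPU-only machines.
num_processes = torch.cuda.device_count() or 1
notebook_launcher(training_loop, ("fp16", 42, 64), num_processes=num_processes)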