Path: blob/main/beginner_source/hyperparameter_tuning_tutorial.py
3697 views
"""1Hyperparameter tuning using Ray Tune2====================================34**Author:** `Ricardo Decal <https://github.com/crypdick>`__56This tutorial shows how to integrate Ray Tune into your PyTorch training7workflow to perform scalable and efficient hyperparameter tuning.89.. grid:: 21011.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn12:class-card: card-prerequisites1314* How to modify a PyTorch training loop for Ray Tune15* How to scale a hyperparameter sweep to multiple nodes and GPUs without code changes16* How to define a hyperparameter search space and run a sweep with ``tune.Tuner``17* How to use an early-stopping scheduler (ASHA) and report metrics/checkpoints18* How to use checkpointing to resume training and load the best model1920.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites21:class-card: card-prerequisites2223* PyTorch v2.9+ and ``torchvision``24* Ray Tune (``ray[tune]``) v2.52.1+25* GPU(s) are optional, but recommended for faster training2627`Ray <https://docs.ray.io/en/latest/index.html>`__, a project of the28PyTorch Foundation, is an open source unified framework for scaling AI29and Python applications. It helps run distributed jobs by handling the30complexity of distributed computing. `Ray31Tune <https://docs.ray.io/en/latest/tune/index.html>`__ is a library32built on Ray for hyperparameter tuning that enables you to scale a33hyperparameter sweep from your machine to a large cluster with no code34changes.3536This tutorial adapts the `PyTorch tutorial for training a CIFAR1037classifier <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`__38to run multi-GPU hyperparameter sweeps with Ray Tune.3940Setup41-----4243To run this tutorial, install the following dependencies:4445.. code-block:: bash4647pip install "ray[tune]" torchvision4849"""5051######################################################################52# Then start with the imports:5354from functools import partial55import os56import tempfile57from pathlib import Path58import torch59import torch.nn as nn60import torch.nn.functional as F61import torch.optim as optim62from torch.utils.data import random_split63import torchvision64import torchvision.transforms as transforms65# New: imports for Ray Tune66import ray67from ray import tune68from ray.tune import Checkpoint69from ray.tune.schedulers import ASHAScheduler7071######################################################################72# Data loading73# ============74#75# Wrap the data loaders in a constructor function. In this tutorial, a76# global data directory is passed to the function to enable reusing the77# dataset across different trials. In a cluster environment, you can use78# shared storage, such as network file systems, to prevent each node from79# downloading the data separately.8081def load_data(data_dir="./data"):82# Mean and standard deviation of the CIFAR10 training subset.83transform = transforms.Compose(84[transforms.ToTensor(), transforms.Normalize((0.4914, 0.48216, 0.44653), (0.2022, 0.19932, 0.20086))]85)8687trainset = torchvision.datasets.CIFAR10(88root=data_dir, train=True, download=True, transform=transform89)9091testset = torchvision.datasets.CIFAR10(92root=data_dir, train=False, download=True, transform=transform93)9495return trainset, testset9697######################################################################98# Model architecture99# ==================100#101# This tutorial searches for the best sizes for the fully connected layers102# and the learning rate. To enable this, the ``Net`` class exposes the103# layer sizes ``l1`` and ``l2`` as configurable parameters that Ray Tune104# can search over:105106class Net(nn.Module):107def __init__(self, l1=120, l2=84):108super().__init__()109self.conv1 = nn.Conv2d(3, 6, 5)110self.pool = nn.MaxPool2d(2, 2)111self.conv2 = nn.Conv2d(6, 16, 5)112self.fc1 = nn.Linear(16 * 5 * 5, l1)113self.fc2 = nn.Linear(l1, l2)114self.fc3 = nn.Linear(l2, 10)115116def forward(self, x):117x = self.pool(F.relu(self.conv1(x)))118x = self.pool(F.relu(self.conv2(x)))119x = torch.flatten(x, 1) # flatten all dimensions except batch120x = F.relu(self.fc1(x))121x = F.relu(self.fc2(x))122x = self.fc3(x)123return x124125######################################################################126# Define the search space127# =======================128#129# Next, define the hyperparameters to tune and how Ray Tune samples them.130# Ray Tune offers a variety of `search space131# distributions <https://docs.ray.io/en/latest/tune/api/search_space.html>`__132# to suit different parameter types: ``loguniform``, ``uniform``,133# ``choice``, ``randint``, ``grid``, and more. You can also express134# complex dependencies between parameters with `conditional search135# spaces <https://docs.ray.io/en/latest/tune/tutorials/tune-search-spaces.html#how-to-use-custom-and-conditional-search-spaces-in-tune>`__136# or sample from arbitrary functions.137#138# Here is the search space for this tutorial:139#140# .. code-block:: python141#142# config = {143# "l1": tune.choice([2**i for i in range(9)]),144# "l2": tune.choice([2**i for i in range(9)]),145# "lr": tune.loguniform(1e-4, 1e-1),146# "batch_size": tune.choice([2, 4, 8, 16]),147# }148#149# The ``tune.choice()`` accepts a list of values that are uniformly150# sampled from. In this example, the ``l1`` and ``l2`` parameter values151# are powers of 2 between 1 and 256, and the learning rate samples on a152# log scale between 0.0001 and 0.1. Sampling on a log scale enables153# exploration across a range of magnitudes on a relative scale, rather154# than an absolute scale.155#156# Training function157# =================158#159# Ray Tune requires a training function that accepts a configuration160# dictionary and runs the main training loop. As Ray Tune runs different161# trials, it updates the configuration dictionary for each trial.162#163# Here is the full training function, followed by explanations of the key164# Ray Tune integration points:165166def train_cifar(config, data_dir=None):167net = Net(config["l1"], config["l2"])168device = config["device"]169170net = net.to(device)171if torch.cuda.device_count() > 1:172net = nn.DataParallel(net)173174criterion = nn.CrossEntropyLoss()175optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)176177# Load checkpoint if resuming training178checkpoint = tune.get_checkpoint()179if checkpoint:180with checkpoint.as_directory() as checkpoint_dir:181checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"182checkpoint_state = torch.load(checkpoint_path)183start_epoch = checkpoint_state["epoch"]184net.load_state_dict(checkpoint_state["net_state_dict"])185optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])186else:187start_epoch = 0188189trainset, _testset = load_data(data_dir)190191test_abs = int(len(trainset) * 0.8)192train_subset, val_subset = random_split(193trainset, [test_abs, len(trainset) - test_abs]194)195196trainloader = torch.utils.data.DataLoader(197train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8198)199valloader = torch.utils.data.DataLoader(200val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8201)202203for epoch in range(start_epoch, 10): # loop over the dataset multiple times204running_loss = 0.0205epoch_steps = 0206for i, data in enumerate(trainloader, 0):207# get the inputs; data is a list of [inputs, labels]208inputs, labels = data209inputs, labels = inputs.to(device), labels.to(device)210211# zero the parameter gradients212optimizer.zero_grad()213214# forward + backward + optimize215outputs = net(inputs)216loss = criterion(outputs, labels)217loss.backward()218optimizer.step()219220# print statistics221running_loss += loss.item()222epoch_steps += 1223if i % 2000 == 1999: # print every 2000 mini-batches224print(225"[%d, %5d] loss: %.3f"226% (epoch + 1, i + 1, running_loss / epoch_steps)227)228running_loss = 0.0229230# Validation loss231val_loss = 0.0232val_steps = 0233total = 0234correct = 0235for i, data in enumerate(valloader, 0):236with torch.no_grad():237inputs, labels = data238inputs, labels = inputs.to(device), labels.to(device)239240outputs = net(inputs)241_, predicted = torch.max(outputs.data, 1)242total += labels.size(0)243correct += (predicted == labels).sum().item()244245loss = criterion(outputs, labels)246val_loss += loss.cpu().numpy()247val_steps += 1248249# Save checkpoint and report metrics250checkpoint_data = {251"epoch": epoch,252"net_state_dict": net.state_dict(),253"optimizer_state_dict": optimizer.state_dict(),254}255with tempfile.TemporaryDirectory() as checkpoint_dir:256checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"257torch.save(checkpoint_data, checkpoint_path)258259checkpoint = Checkpoint.from_directory(checkpoint_dir)260tune.report(261{"loss": val_loss / val_steps, "accuracy": correct / total},262checkpoint=checkpoint,263)264265print("Finished Training")266267######################################################################268# Key integration points269# ----------------------270#271# Using hyperparameters from the configuration dictionary272# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~273#274# Ray Tune updates the ``config`` dictionary with the hyperparameters for275# each trial. In this example, the model architecture and optimizer276# receive the hyperparameters from the ``config`` dictionary:277#278# .. code-block:: python279#280# net = Net(config["l1"], config["l2"])281# optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)282#283# Reporting metrics and saving checkpoints284# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~285#286# The most important integration is communicating with Ray Tune. Ray Tune287# uses the validation metrics to determine the best hyperparameter288# configuration and to stop underperforming trials early, saving289# resources.290#291# Checkpointing enables you to later load the trained models, resume292# hyperparameter searches, and provides fault tolerance. It’s also293# required for some Ray Tune schedulers like `Population Based294# Training <https://docs.ray.io/en/latest/tune/examples/pbt_guide.html>`__295# that pause and resume trials during the search.296#297# This code from the training function loads model and optimizer state at298# the start if a checkpoint exists:299#300# .. code-block:: python301#302# checkpoint = tune.get_checkpoint()303# if checkpoint:304# with checkpoint.as_directory() as checkpoint_dir:305# checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"306# checkpoint_state = torch.load(checkpoint_path)307# start_epoch = checkpoint_state["epoch"]308# net.load_state_dict(checkpoint_state["net_state_dict"])309# optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])310#311# At the end of each epoch, save a checkpoint and report the validation312# metrics:313#314# .. code-block:: python315#316# checkpoint_data = {317# "epoch": epoch,318# "net_state_dict": net.state_dict(),319# "optimizer_state_dict": optimizer.state_dict(),320# }321# with tempfile.TemporaryDirectory() as checkpoint_dir:322# checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"323# torch.save(checkpoint_data, checkpoint_path)324#325# checkpoint = Checkpoint.from_directory(checkpoint_dir)326# tune.report(327# {"loss": val_loss / val_steps, "accuracy": correct / total},328# checkpoint=checkpoint,329# )330#331# Ray Tune checkpointing supports local file systems, cloud storage, and332# distributed file systems. For more information, see the `Ray Tune333# storage334# documentation <https://docs.ray.io/en/latest/tune/tutorials/tune-storage.html>`__.335#336# Multi-GPU support337# ~~~~~~~~~~~~~~~~~338#339# Image classification models can be greatly accelerated by using GPUs.340# The training function supports multi-GPU training by wrapping the model341# in ``nn.DataParallel``:342#343# .. code-block:: python344#345# if torch.cuda.device_count() > 1:346# net = nn.DataParallel(net)347#348# This training function supports training on CPUs, a single GPU, multiple GPUs, or349# multiple nodes without code changes. Ray Tune automatically distributes the trials350# across the nodes according to the available resources. Ray Tune also supports `fractional351# GPUs <https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html#fractional-accelerators>`__352# so that one GPU can be shared among multiple trials, provided that the353# models, optimizers, and data batches fit into the GPU memory.354#355# Validation split356# ~~~~~~~~~~~~~~~~357#358# The original CIFAR10 dataset only has train and test subsets. This is359# sufficient for training a single model, however for hyperparameter360# tuning a validation subset is required. The training function creates a361# validation subset by reserving 20% of the training subset. The test362# subset is used to evaluate the best model’s generalization error after363# the search completes.364#365# Evaluation function366# ===================367#368# After finding the optimal hyperparameters, test the model on a held-out369# test set to estimate the generalization error:370371def test_accuracy(net, device="cpu", data_dir=None):372_trainset, testset = load_data(data_dir)373374testloader = torch.utils.data.DataLoader(375testset, batch_size=4, shuffle=False, num_workers=2376)377378correct = 0379total = 0380with torch.no_grad():381for data in testloader:382image_batch, labels = data383image_batch, labels = image_batch.to(device), labels.to(device)384outputs = net(image_batch)385_, predicted = torch.max(outputs.data, 1)386total += labels.size(0)387correct += (predicted == labels).sum().item()388389return correct / total390391######################################################################392# Configure and run Ray Tune393# ==========================394#395# With the training and evaluation functions defined, configure Ray Tune396# to run the hyperparameter search.397#398# Scheduler for early stopping399# ----------------------------400#401# Ray Tune provides schedulers to improve the efficiency of the402# hyperparameter search by detecting underperforming trials and stopping403# them early. The ``ASHAScheduler`` uses the Asynchronous Successive404# Halving Algorithm (ASHA) to aggressively terminate low-performing405# trials:406#407# .. code-block:: python408#409# scheduler = ASHAScheduler(410# max_t=max_num_epochs,411# grace_period=1,412# reduction_factor=2,413# )414#415# Ray Tune also provides `advanced search416# algorithms <https://docs.ray.io/en/latest/tune/api/suggestion.html>`__417# to smartly pick the next set of hyperparameters based on previous418# results, instead of relying only on random or grid search. Examples419# include420# `Optuna <https://docs.ray.io/en/latest/tune/api/suggestion.html#optuna>`__421# and422# `BayesOpt <https://docs.ray.io/en/latest/tune/api/suggestion.html#bayesopt>`__.423#424# Resource allocation425# -------------------426#427# Tell Ray Tune what resources to allocate for each trial by passing a428# ``resources`` dictionary to ``tune.with_resources``:429#430# .. code-block:: python431#432# tune.with_resources(433# partial(train_cifar, data_dir=data_dir),434# resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial}435# )436#437# Ray Tune automatically manages the placement of these trials and ensures438# that the trials run in isolation, so you don’t need to manually assign439# GPUs to processes.440#441# For example, if you are running this experiment on a cluster of 20442# machines, each with 8 GPUs, you can set ``gpus_per_trial = 0.5`` to443# schedule two concurrent trials per GPU. This configuration runs 320444# trials in parallel across the cluster.445#446# .. note::447# To run this tutorial without GPUs, set ``gpus_per_trial=0``448# and expect significantly longer runtimes.449#450# To avoid long runtimes during development, start with a small number451# of trials and epochs.452#453# Creating the Tuner454# ------------------455#456# The Ray Tune API is modular and composable. Pass your configuration to457# the ``tune.Tuner`` class to create a tuner object, then run458# ``tuner.fit()`` to start training:459#460# .. code-block:: python461#462# tuner = tune.Tuner(463# tune.with_resources(464# partial(train_cifar, data_dir=data_dir),465# resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial}466# ),467# tune_config=tune.TuneConfig(468# metric="loss",469# mode="min",470# scheduler=scheduler,471# num_samples=num_trials,472# ),473# param_space=config,474# )475# results = tuner.fit()476#477# After training completes, retrieve the best performing trial, load its478# checkpoint, and evaluate on the test set.479#480# Putting it all together481# -----------------------482483def main(num_trials=10, max_num_epochs=10, gpus_per_trial=0, cpus_per_trial=2):484print("Starting hyperparameter tuning.")485ray.init(include_dashboard=False)486487data_dir = os.path.abspath("./data")488load_data(data_dir) # Pre-download the dataset489device = "cuda" if torch.cuda.is_available() else "cpu"490config = {491"l1": tune.choice([2**i for i in range(9)]),492"l2": tune.choice([2**i for i in range(9)]),493"lr": tune.loguniform(1e-4, 1e-1),494"batch_size": tune.choice([2, 4, 8, 16]),495"device": device,496}497scheduler = ASHAScheduler(498max_t=max_num_epochs,499grace_period=1,500reduction_factor=2,501)502503tuner = tune.Tuner(504tune.with_resources(505partial(train_cifar, data_dir=data_dir),506resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial}507),508tune_config=tune.TuneConfig(509metric="loss",510mode="min",511scheduler=scheduler,512num_samples=num_trials,513),514param_space=config,515)516results = tuner.fit()517518best_result = results.get_best_result("loss", "min")519print(f"Best trial config: {best_result.config}")520print(f"Best trial final validation loss: {best_result.metrics['loss']}")521print(f"Best trial final validation accuracy: {best_result.metrics['accuracy']}")522523best_trained_model = Net(best_result.config["l1"], best_result.config["l2"])524best_trained_model = best_trained_model.to(device)525if gpus_per_trial > 1:526best_trained_model = nn.DataParallel(best_trained_model)527528best_checkpoint = best_result.checkpoint529with best_checkpoint.as_directory() as checkpoint_dir:530checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"531best_checkpoint_data = torch.load(checkpoint_path)532533best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])534test_acc = test_accuracy(best_trained_model, device, data_dir)535print(f"Best trial test set accuracy: {test_acc}")536537538if __name__ == "__main__":539# Set the number of trials, epochs, and GPUs per trial here:540main(num_trials=10, max_num_epochs=10, gpus_per_trial=1)541542######################################################################543# Results544# =======545#546# Your Ray Tune trial summary output looks something like this. The text547# table summarizes the validation performance of the trials and highlights548# the best hyperparameter configuration:549#550# .. code-block:: bash551#552# Number of trials: 10/10 (10 TERMINATED)553# +-----+--------------+------+------+-------------+--------+---------+------------+554# | ... | batch_size | l1 | l2 | lr | iter | loss | accuracy |555# |-----+--------------+------+------+-------------+--------+---------+------------|556# | ... | 2 | 1 | 256 | 0.000668163 | 1 | 2.31479 | 0.0977 |557# | ... | 4 | 64 | 8 | 0.0331514 | 1 | 2.31605 | 0.0983 |558# | ... | 4 | 2 | 1 | 0.000150295 | 1 | 2.30755 | 0.1023 |559# | ... | 16 | 32 | 32 | 0.0128248 | 10 | 1.66912 | 0.4391 |560# | ... | 4 | 8 | 128 | 0.00464561 | 2 | 1.7316 | 0.3463 |561# | ... | 8 | 256 | 8 | 0.00031556 | 1 | 2.19409 | 0.1736 |562# | ... | 4 | 16 | 256 | 0.00574329 | 2 | 1.85679 | 0.3368 |563# | ... | 8 | 2 | 2 | 0.00325652 | 1 | 2.30272 | 0.0984 |564# | ... | 2 | 2 | 2 | 0.000342987 | 2 | 1.76044 | 0.292 |565# | ... | 4 | 64 | 32 | 0.003734 | 8 | 1.53101 | 0.4761 |566# +-----+--------------+------+------+-------------+--------+---------+------------+567#568# Best trial config: {'l1': 64, 'l2': 32, 'lr': 0.0037339984519545164, 'batch_size': 4}569# Best trial final validation loss: 1.5310075663924216570# Best trial final validation accuracy: 0.4761571# Best trial test set accuracy: 0.4737572#573# Most trials stopped early to conserve resources. The best performing574# trial achieved a validation accuracy of approximately 47%, which the575# test set confirms.576#577# Observability578# =============579#580# Monitoring is critical when running large-scale experiments. Ray581# provides a582# `dashboard <https://docs.ray.io/en/latest/ray-observability/getting-started.html>`__583# that lets you view the status of your trials, check cluster resource584# use, and inspect logs in real time.585#586# For debugging, Ray also offers `distributed debugging587# tools <https://docs.ray.io/en/latest/ray-observability/index.html>`__588# that let you attach a debugger to running trials across the cluster.589#590# Conclusion591# ==========592#593# In this tutorial, you learned how to tune the hyperparameters of a594# PyTorch model using Ray Tune. You saw how to integrate Ray Tune into595# your PyTorch training loop, define a search space for your596# hyperparameters, use an efficient scheduler like ``ASHAScheduler`` to597# terminate low-performing trials early, save checkpoints and report598# metrics to Ray Tune, and run the hyperparameter search and analyze the599# results.600#601# Ray Tune makes it straightforward to scale your experiments from a602# single machine to a large cluster, helping you find the best model603# configuration efficiently.604#605# Further reading606# ===============607#608# - `Ray Tune609# documentation <https://docs.ray.io/en/latest/tune/index.html>`__610# - `Ray Tune611# examples <https://docs.ray.io/en/latest/tune/examples/index.html>`__612613614