Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pytorch
GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/hyperparameter_tuning_tutorial.py
3697 views
1
"""
2
Hyperparameter tuning using Ray Tune
3
====================================
4
5
**Author:** `Ricardo Decal <https://github.com/crypdick>`__
6
7
This tutorial shows how to integrate Ray Tune into your PyTorch training
8
workflow to perform scalable and efficient hyperparameter tuning.
9
10
.. grid:: 2
11
12
.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
13
:class-card: card-prerequisites
14
15
* How to modify a PyTorch training loop for Ray Tune
16
* How to scale a hyperparameter sweep to multiple nodes and GPUs without code changes
17
* How to define a hyperparameter search space and run a sweep with ``tune.Tuner``
18
* How to use an early-stopping scheduler (ASHA) and report metrics/checkpoints
19
* How to use checkpointing to resume training and load the best model
20
21
.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
22
:class-card: card-prerequisites
23
24
* PyTorch v2.9+ and ``torchvision``
25
* Ray Tune (``ray[tune]``) v2.52.1+
26
* GPU(s) are optional, but recommended for faster training
27
28
`Ray <https://docs.ray.io/en/latest/index.html>`__, a project of the
29
PyTorch Foundation, is an open source unified framework for scaling AI
30
and Python applications. It helps run distributed jobs by handling the
31
complexity of distributed computing. `Ray
32
Tune <https://docs.ray.io/en/latest/tune/index.html>`__ is a library
33
built on Ray for hyperparameter tuning that enables you to scale a
34
hyperparameter sweep from your machine to a large cluster with no code
35
changes.
36
37
This tutorial adapts the `PyTorch tutorial for training a CIFAR10
38
classifier <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`__
39
to run multi-GPU hyperparameter sweeps with Ray Tune.
40
41
Setup
42
-----
43
44
To run this tutorial, install the following dependencies:
45
46
.. code-block:: bash
47
48
pip install "ray[tune]" torchvision
49
50
"""
51
52
######################################################################
53
# Then start with the imports:
54
55
from functools import partial
56
import os
57
import tempfile
58
from pathlib import Path
59
import torch
60
import torch.nn as nn
61
import torch.nn.functional as F
62
import torch.optim as optim
63
from torch.utils.data import random_split
64
import torchvision
65
import torchvision.transforms as transforms
66
# New: imports for Ray Tune
67
import ray
68
from ray import tune
69
from ray.tune import Checkpoint
70
from ray.tune.schedulers import ASHAScheduler
71
72
######################################################################
73
# Data loading
74
# ============
75
#
76
# Wrap the data loaders in a constructor function. In this tutorial, a
77
# global data directory is passed to the function to enable reusing the
78
# dataset across different trials. In a cluster environment, you can use
79
# shared storage, such as network file systems, to prevent each node from
80
# downloading the data separately.
81
82
def load_data(data_dir="./data"):
83
# Mean and standard deviation of the CIFAR10 training subset.
84
transform = transforms.Compose(
85
[transforms.ToTensor(), transforms.Normalize((0.4914, 0.48216, 0.44653), (0.2022, 0.19932, 0.20086))]
86
)
87
88
trainset = torchvision.datasets.CIFAR10(
89
root=data_dir, train=True, download=True, transform=transform
90
)
91
92
testset = torchvision.datasets.CIFAR10(
93
root=data_dir, train=False, download=True, transform=transform
94
)
95
96
return trainset, testset
97
98
######################################################################
99
# Model architecture
100
# ==================
101
#
102
# This tutorial searches for the best sizes for the fully connected layers
103
# and the learning rate. To enable this, the ``Net`` class exposes the
104
# layer sizes ``l1`` and ``l2`` as configurable parameters that Ray Tune
105
# can search over:
106
107
class Net(nn.Module):
108
def __init__(self, l1=120, l2=84):
109
super().__init__()
110
self.conv1 = nn.Conv2d(3, 6, 5)
111
self.pool = nn.MaxPool2d(2, 2)
112
self.conv2 = nn.Conv2d(6, 16, 5)
113
self.fc1 = nn.Linear(16 * 5 * 5, l1)
114
self.fc2 = nn.Linear(l1, l2)
115
self.fc3 = nn.Linear(l2, 10)
116
117
def forward(self, x):
118
x = self.pool(F.relu(self.conv1(x)))
119
x = self.pool(F.relu(self.conv2(x)))
120
x = torch.flatten(x, 1) # flatten all dimensions except batch
121
x = F.relu(self.fc1(x))
122
x = F.relu(self.fc2(x))
123
x = self.fc3(x)
124
return x
125
126
######################################################################
127
# Define the search space
128
# =======================
129
#
130
# Next, define the hyperparameters to tune and how Ray Tune samples them.
131
# Ray Tune offers a variety of `search space
132
# distributions <https://docs.ray.io/en/latest/tune/api/search_space.html>`__
133
# to suit different parameter types: ``loguniform``, ``uniform``,
134
# ``choice``, ``randint``, ``grid``, and more. You can also express
135
# complex dependencies between parameters with `conditional search
136
# spaces <https://docs.ray.io/en/latest/tune/tutorials/tune-search-spaces.html#how-to-use-custom-and-conditional-search-spaces-in-tune>`__
137
# or sample from arbitrary functions.
138
#
139
# Here is the search space for this tutorial:
140
#
141
# .. code-block:: python
142
#
143
# config = {
144
# "l1": tune.choice([2**i for i in range(9)]),
145
# "l2": tune.choice([2**i for i in range(9)]),
146
# "lr": tune.loguniform(1e-4, 1e-1),
147
# "batch_size": tune.choice([2, 4, 8, 16]),
148
# }
149
#
150
# The ``tune.choice()`` accepts a list of values that are uniformly
151
# sampled from. In this example, the ``l1`` and ``l2`` parameter values
152
# are powers of 2 between 1 and 256, and the learning rate samples on a
153
# log scale between 0.0001 and 0.1. Sampling on a log scale enables
154
# exploration across a range of magnitudes on a relative scale, rather
155
# than an absolute scale.
156
#
157
# Training function
158
# =================
159
#
160
# Ray Tune requires a training function that accepts a configuration
161
# dictionary and runs the main training loop. As Ray Tune runs different
162
# trials, it updates the configuration dictionary for each trial.
163
#
164
# Here is the full training function, followed by explanations of the key
165
# Ray Tune integration points:
166
167
def train_cifar(config, data_dir=None):
168
net = Net(config["l1"], config["l2"])
169
device = config["device"]
170
171
net = net.to(device)
172
if torch.cuda.device_count() > 1:
173
net = nn.DataParallel(net)
174
175
criterion = nn.CrossEntropyLoss()
176
optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
177
178
# Load checkpoint if resuming training
179
checkpoint = tune.get_checkpoint()
180
if checkpoint:
181
with checkpoint.as_directory() as checkpoint_dir:
182
checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
183
checkpoint_state = torch.load(checkpoint_path)
184
start_epoch = checkpoint_state["epoch"]
185
net.load_state_dict(checkpoint_state["net_state_dict"])
186
optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
187
else:
188
start_epoch = 0
189
190
trainset, _testset = load_data(data_dir)
191
192
test_abs = int(len(trainset) * 0.8)
193
train_subset, val_subset = random_split(
194
trainset, [test_abs, len(trainset) - test_abs]
195
)
196
197
trainloader = torch.utils.data.DataLoader(
198
train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
199
)
200
valloader = torch.utils.data.DataLoader(
201
val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
202
)
203
204
for epoch in range(start_epoch, 10): # loop over the dataset multiple times
205
running_loss = 0.0
206
epoch_steps = 0
207
for i, data in enumerate(trainloader, 0):
208
# get the inputs; data is a list of [inputs, labels]
209
inputs, labels = data
210
inputs, labels = inputs.to(device), labels.to(device)
211
212
# zero the parameter gradients
213
optimizer.zero_grad()
214
215
# forward + backward + optimize
216
outputs = net(inputs)
217
loss = criterion(outputs, labels)
218
loss.backward()
219
optimizer.step()
220
221
# print statistics
222
running_loss += loss.item()
223
epoch_steps += 1
224
if i % 2000 == 1999: # print every 2000 mini-batches
225
print(
226
"[%d, %5d] loss: %.3f"
227
% (epoch + 1, i + 1, running_loss / epoch_steps)
228
)
229
running_loss = 0.0
230
231
# Validation loss
232
val_loss = 0.0
233
val_steps = 0
234
total = 0
235
correct = 0
236
for i, data in enumerate(valloader, 0):
237
with torch.no_grad():
238
inputs, labels = data
239
inputs, labels = inputs.to(device), labels.to(device)
240
241
outputs = net(inputs)
242
_, predicted = torch.max(outputs.data, 1)
243
total += labels.size(0)
244
correct += (predicted == labels).sum().item()
245
246
loss = criterion(outputs, labels)
247
val_loss += loss.cpu().numpy()
248
val_steps += 1
249
250
# Save checkpoint and report metrics
251
checkpoint_data = {
252
"epoch": epoch,
253
"net_state_dict": net.state_dict(),
254
"optimizer_state_dict": optimizer.state_dict(),
255
}
256
with tempfile.TemporaryDirectory() as checkpoint_dir:
257
checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
258
torch.save(checkpoint_data, checkpoint_path)
259
260
checkpoint = Checkpoint.from_directory(checkpoint_dir)
261
tune.report(
262
{"loss": val_loss / val_steps, "accuracy": correct / total},
263
checkpoint=checkpoint,
264
)
265
266
print("Finished Training")
267
268
######################################################################
269
# Key integration points
270
# ----------------------
271
#
272
# Using hyperparameters from the configuration dictionary
273
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
274
#
275
# Ray Tune updates the ``config`` dictionary with the hyperparameters for
276
# each trial. In this example, the model architecture and optimizer
277
# receive the hyperparameters from the ``config`` dictionary:
278
#
279
# .. code-block:: python
280
#
281
# net = Net(config["l1"], config["l2"])
282
# optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
283
#
284
# Reporting metrics and saving checkpoints
285
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
286
#
287
# The most important integration is communicating with Ray Tune. Ray Tune
288
# uses the validation metrics to determine the best hyperparameter
289
# configuration and to stop underperforming trials early, saving
290
# resources.
291
#
292
# Checkpointing enables you to later load the trained models, resume
293
# hyperparameter searches, and provides fault tolerance. It’s also
294
# required for some Ray Tune schedulers like `Population Based
295
# Training <https://docs.ray.io/en/latest/tune/examples/pbt_guide.html>`__
296
# that pause and resume trials during the search.
297
#
298
# This code from the training function loads model and optimizer state at
299
# the start if a checkpoint exists:
300
#
301
# .. code-block:: python
302
#
303
# checkpoint = tune.get_checkpoint()
304
# if checkpoint:
305
# with checkpoint.as_directory() as checkpoint_dir:
306
# checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
307
# checkpoint_state = torch.load(checkpoint_path)
308
# start_epoch = checkpoint_state["epoch"]
309
# net.load_state_dict(checkpoint_state["net_state_dict"])
310
# optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
311
#
312
# At the end of each epoch, save a checkpoint and report the validation
313
# metrics:
314
#
315
# .. code-block:: python
316
#
317
# checkpoint_data = {
318
# "epoch": epoch,
319
# "net_state_dict": net.state_dict(),
320
# "optimizer_state_dict": optimizer.state_dict(),
321
# }
322
# with tempfile.TemporaryDirectory() as checkpoint_dir:
323
# checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
324
# torch.save(checkpoint_data, checkpoint_path)
325
#
326
# checkpoint = Checkpoint.from_directory(checkpoint_dir)
327
# tune.report(
328
# {"loss": val_loss / val_steps, "accuracy": correct / total},
329
# checkpoint=checkpoint,
330
# )
331
#
332
# Ray Tune checkpointing supports local file systems, cloud storage, and
333
# distributed file systems. For more information, see the `Ray Tune
334
# storage
335
# documentation <https://docs.ray.io/en/latest/tune/tutorials/tune-storage.html>`__.
336
#
337
# Multi-GPU support
338
# ~~~~~~~~~~~~~~~~~
339
#
340
# Image classification models can be greatly accelerated by using GPUs.
341
# The training function supports multi-GPU training by wrapping the model
342
# in ``nn.DataParallel``:
343
#
344
# .. code-block:: python
345
#
346
# if torch.cuda.device_count() > 1:
347
# net = nn.DataParallel(net)
348
#
349
# This training function supports training on CPUs, a single GPU, multiple GPUs, or
350
# multiple nodes without code changes. Ray Tune automatically distributes the trials
351
# across the nodes according to the available resources. Ray Tune also supports `fractional
352
# GPUs <https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html#fractional-accelerators>`__
353
# so that one GPU can be shared among multiple trials, provided that the
354
# models, optimizers, and data batches fit into the GPU memory.
355
#
356
# Validation split
357
# ~~~~~~~~~~~~~~~~
358
#
359
# The original CIFAR10 dataset only has train and test subsets. This is
360
# sufficient for training a single model, however for hyperparameter
361
# tuning a validation subset is required. The training function creates a
362
# validation subset by reserving 20% of the training subset. The test
363
# subset is used to evaluate the best model’s generalization error after
364
# the search completes.
365
#
366
# Evaluation function
367
# ===================
368
#
369
# After finding the optimal hyperparameters, test the model on a held-out
370
# test set to estimate the generalization error:
371
372
def test_accuracy(net, device="cpu", data_dir=None):
373
_trainset, testset = load_data(data_dir)
374
375
testloader = torch.utils.data.DataLoader(
376
testset, batch_size=4, shuffle=False, num_workers=2
377
)
378
379
correct = 0
380
total = 0
381
with torch.no_grad():
382
for data in testloader:
383
image_batch, labels = data
384
image_batch, labels = image_batch.to(device), labels.to(device)
385
outputs = net(image_batch)
386
_, predicted = torch.max(outputs.data, 1)
387
total += labels.size(0)
388
correct += (predicted == labels).sum().item()
389
390
return correct / total
391
392
######################################################################
393
# Configure and run Ray Tune
394
# ==========================
395
#
396
# With the training and evaluation functions defined, configure Ray Tune
397
# to run the hyperparameter search.
398
#
399
# Scheduler for early stopping
400
# ----------------------------
401
#
402
# Ray Tune provides schedulers to improve the efficiency of the
403
# hyperparameter search by detecting underperforming trials and stopping
404
# them early. The ``ASHAScheduler`` uses the Asynchronous Successive
405
# Halving Algorithm (ASHA) to aggressively terminate low-performing
406
# trials:
407
#
408
# .. code-block:: python
409
#
410
# scheduler = ASHAScheduler(
411
# max_t=max_num_epochs,
412
# grace_period=1,
413
# reduction_factor=2,
414
# )
415
#
416
# Ray Tune also provides `advanced search
417
# algorithms <https://docs.ray.io/en/latest/tune/api/suggestion.html>`__
418
# to smartly pick the next set of hyperparameters based on previous
419
# results, instead of relying only on random or grid search. Examples
420
# include
421
# `Optuna <https://docs.ray.io/en/latest/tune/api/suggestion.html#optuna>`__
422
# and
423
# `BayesOpt <https://docs.ray.io/en/latest/tune/api/suggestion.html#bayesopt>`__.
424
#
425
# Resource allocation
426
# -------------------
427
#
428
# Tell Ray Tune what resources to allocate for each trial by passing a
429
# ``resources`` dictionary to ``tune.with_resources``:
430
#
431
# .. code-block:: python
432
#
433
# tune.with_resources(
434
# partial(train_cifar, data_dir=data_dir),
435
# resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial}
436
# )
437
#
438
# Ray Tune automatically manages the placement of these trials and ensures
439
# that the trials run in isolation, so you don’t need to manually assign
440
# GPUs to processes.
441
#
442
# For example, if you are running this experiment on a cluster of 20
443
# machines, each with 8 GPUs, you can set ``gpus_per_trial = 0.5`` to
444
# schedule two concurrent trials per GPU. This configuration runs 320
445
# trials in parallel across the cluster.
446
#
447
# .. note::
448
# To run this tutorial without GPUs, set ``gpus_per_trial=0``
449
# and expect significantly longer runtimes.
450
#
451
# To avoid long runtimes during development, start with a small number
452
# of trials and epochs.
453
#
454
# Creating the Tuner
455
# ------------------
456
#
457
# The Ray Tune API is modular and composable. Pass your configuration to
458
# the ``tune.Tuner`` class to create a tuner object, then run
459
# ``tuner.fit()`` to start training:
460
#
461
# .. code-block:: python
462
#
463
# tuner = tune.Tuner(
464
# tune.with_resources(
465
# partial(train_cifar, data_dir=data_dir),
466
# resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial}
467
# ),
468
# tune_config=tune.TuneConfig(
469
# metric="loss",
470
# mode="min",
471
# scheduler=scheduler,
472
# num_samples=num_trials,
473
# ),
474
# param_space=config,
475
# )
476
# results = tuner.fit()
477
#
478
# After training completes, retrieve the best performing trial, load its
479
# checkpoint, and evaluate on the test set.
480
#
481
# Putting it all together
482
# -----------------------
483
484
def main(num_trials=10, max_num_epochs=10, gpus_per_trial=0, cpus_per_trial=2):
485
print("Starting hyperparameter tuning.")
486
ray.init(include_dashboard=False)
487
488
data_dir = os.path.abspath("./data")
489
load_data(data_dir) # Pre-download the dataset
490
device = "cuda" if torch.cuda.is_available() else "cpu"
491
config = {
492
"l1": tune.choice([2**i for i in range(9)]),
493
"l2": tune.choice([2**i for i in range(9)]),
494
"lr": tune.loguniform(1e-4, 1e-1),
495
"batch_size": tune.choice([2, 4, 8, 16]),
496
"device": device,
497
}
498
scheduler = ASHAScheduler(
499
max_t=max_num_epochs,
500
grace_period=1,
501
reduction_factor=2,
502
)
503
504
tuner = tune.Tuner(
505
tune.with_resources(
506
partial(train_cifar, data_dir=data_dir),
507
resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial}
508
),
509
tune_config=tune.TuneConfig(
510
metric="loss",
511
mode="min",
512
scheduler=scheduler,
513
num_samples=num_trials,
514
),
515
param_space=config,
516
)
517
results = tuner.fit()
518
519
best_result = results.get_best_result("loss", "min")
520
print(f"Best trial config: {best_result.config}")
521
print(f"Best trial final validation loss: {best_result.metrics['loss']}")
522
print(f"Best trial final validation accuracy: {best_result.metrics['accuracy']}")
523
524
best_trained_model = Net(best_result.config["l1"], best_result.config["l2"])
525
best_trained_model = best_trained_model.to(device)
526
if gpus_per_trial > 1:
527
best_trained_model = nn.DataParallel(best_trained_model)
528
529
best_checkpoint = best_result.checkpoint
530
with best_checkpoint.as_directory() as checkpoint_dir:
531
checkpoint_path = Path(checkpoint_dir) / "checkpoint.pt"
532
best_checkpoint_data = torch.load(checkpoint_path)
533
534
best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])
535
test_acc = test_accuracy(best_trained_model, device, data_dir)
536
print(f"Best trial test set accuracy: {test_acc}")
537
538
539
if __name__ == "__main__":
540
# Set the number of trials, epochs, and GPUs per trial here:
541
main(num_trials=10, max_num_epochs=10, gpus_per_trial=1)
542
543
######################################################################
544
# Results
545
# =======
546
#
547
# Your Ray Tune trial summary output looks something like this. The text
548
# table summarizes the validation performance of the trials and highlights
549
# the best hyperparameter configuration:
550
#
551
# .. code-block:: bash
552
#
553
# Number of trials: 10/10 (10 TERMINATED)
554
# +-----+--------------+------+------+-------------+--------+---------+------------+
555
# | ... | batch_size | l1 | l2 | lr | iter | loss | accuracy |
556
# |-----+--------------+------+------+-------------+--------+---------+------------|
557
# | ... | 2 | 1 | 256 | 0.000668163 | 1 | 2.31479 | 0.0977 |
558
# | ... | 4 | 64 | 8 | 0.0331514 | 1 | 2.31605 | 0.0983 |
559
# | ... | 4 | 2 | 1 | 0.000150295 | 1 | 2.30755 | 0.1023 |
560
# | ... | 16 | 32 | 32 | 0.0128248 | 10 | 1.66912 | 0.4391 |
561
# | ... | 4 | 8 | 128 | 0.00464561 | 2 | 1.7316 | 0.3463 |
562
# | ... | 8 | 256 | 8 | 0.00031556 | 1 | 2.19409 | 0.1736 |
563
# | ... | 4 | 16 | 256 | 0.00574329 | 2 | 1.85679 | 0.3368 |
564
# | ... | 8 | 2 | 2 | 0.00325652 | 1 | 2.30272 | 0.0984 |
565
# | ... | 2 | 2 | 2 | 0.000342987 | 2 | 1.76044 | 0.292 |
566
# | ... | 4 | 64 | 32 | 0.003734 | 8 | 1.53101 | 0.4761 |
567
# +-----+--------------+------+------+-------------+--------+---------+------------+
568
#
569
# Best trial config: {'l1': 64, 'l2': 32, 'lr': 0.0037339984519545164, 'batch_size': 4}
570
# Best trial final validation loss: 1.5310075663924216
571
# Best trial final validation accuracy: 0.4761
572
# Best trial test set accuracy: 0.4737
573
#
574
# Most trials stopped early to conserve resources. The best performing
575
# trial achieved a validation accuracy of approximately 47%, which the
576
# test set confirms.
577
#
578
# Observability
579
# =============
580
#
581
# Monitoring is critical when running large-scale experiments. Ray
582
# provides a
583
# `dashboard <https://docs.ray.io/en/latest/ray-observability/getting-started.html>`__
584
# that lets you view the status of your trials, check cluster resource
585
# use, and inspect logs in real time.
586
#
587
# For debugging, Ray also offers `distributed debugging
588
# tools <https://docs.ray.io/en/latest/ray-observability/index.html>`__
589
# that let you attach a debugger to running trials across the cluster.
590
#
591
# Conclusion
592
# ==========
593
#
594
# In this tutorial, you learned how to tune the hyperparameters of a
595
# PyTorch model using Ray Tune. You saw how to integrate Ray Tune into
596
# your PyTorch training loop, define a search space for your
597
# hyperparameters, use an efficient scheduler like ``ASHAScheduler`` to
598
# terminate low-performing trials early, save checkpoints and report
599
# metrics to Ray Tune, and run the hyperparameter search and analyze the
600
# results.
601
#
602
# Ray Tune makes it straightforward to scale your experiments from a
603
# single machine to a large cluster, helping you find the best model
604
# configuration efficiently.
605
#
606
# Further reading
607
# ===============
608
#
609
# - `Ray Tune
610
# documentation <https://docs.ray.io/en/latest/tune/index.html>`__
611
# - `Ray Tune
612
# examples <https://docs.ray.io/en/latest/tune/examples/index.html>`__
613
614