Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/helpers/device.py
4918 views
1
import torch
2
3
from labml.configs import BaseConfigs, hyperparams, option
4
5
6
class DeviceInfo:
    """
    Resolves the `torch.device` to use from the `use_cuda` / `cuda_device` settings.

    Uses the CPU when `use_cuda` is false or CUDA is not available. Otherwise it
    uses CUDA device `cuda_device`, or the last CUDA device when `cuda_device`
    is not less than `torch.cuda.device_count()`.
    """

    def __init__(self, *,
                 use_cuda: bool,
                 cuda_device: int):
        """
        * `use_cuda` is whether to use a CUDA device when one is available
        * `cuda_device` is the requested CUDA device index
        """
        self.use_cuda = use_cuda
        self.cuda_device = cuda_device
        self.cuda_count = torch.cuda.device_count()

        self.is_cuda = self.use_cuda and torch.cuda.is_available()
        if not self.is_cuda:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda', self._resolved_cuda_index())
        # NOTE(review): a negative `cuda_device` is passed straight through to
        # `torch.device`, which rejects negative indices — confirm whether that
        # should be validated here.

    def _resolved_cuda_index(self) -> int:
        """
        Return the CUDA index actually used: the requested one if it is in range,
        otherwise the last device. Shared by `__init__` and `__str__` so the
        reported device always matches the one selected.
        """
        if self.cuda_device < self.cuda_count:
            return self.cuda_device
        return self.cuda_count - 1

    def __str__(self):
        """
        Human-readable device description, e.g. `CPU`, `GPU:0 - <name>`, or
        `GPU:<used>(<requested>) - <name>` when the requested index was out of range.
        """
        if not self.is_cuda:
            return "CPU"

        index = self._resolved_cuda_index()
        name = torch.cuda.get_device_name(index)
        if index == self.cuda_device:
            return f"GPU:{index} - {name}"
        # Show both the device used and the (out-of-range) device that was requested
        return f"GPU:{index}({self.cuda_device}) - {name}"
32
33
34
class DeviceConfigs(BaseConfigs):
    r"""
    This is a configurable module to get a single device to train a model on.
    It picks a CUDA device when CUDA is enabled and available, and falls back to
    the CPU otherwise. If ``cuda_device`` is out of range (not less than
    ``torch.cuda.device_count()``), the last CUDA device is used instead.

    It has other small advantages such as being able to view the
    actual device name on configurations view of
    `labml app <https://github.com/labmlai/labml/tree/master/app>`_

    Arguments:
        cuda_device (int): The CUDA device number. Defaults to ``0``.
        use_cuda (bool): Whether to use CUDA devices. Defaults to ``True``.
    """
    # Requested CUDA device index
    cuda_device: int = 0
    # Whether to try to use CUDA at all
    use_cuda: bool = True

    # Resolved device details; computed by the `_device_info` option in this module
    device_info: DeviceInfo

    # The selected `torch.device`; computed by the `_device` option in this module
    device: torch.device

    def __init__(self):
        # NOTE(review): `_primary='device'` presumably makes `device` the value this
        # config group resolves to when used by other configs — confirm against
        # labml's `BaseConfigs`.
        super().__init__(_primary='device')
56
57
58
@option(DeviceConfigs.device)
def _device(c: DeviceConfigs):
    """
    Default option for `DeviceConfigs.device`: the `torch.device` selected by
    the config's `DeviceInfo`.
    """
    # Keep this as a plain `c.<attr>` access so the dependency is visible to labml
    info = c.device_info
    return info.device
61
62
63
# Register the device settings with `is_hyperparam=False`.
# NOTE(review): presumably this keeps labml from treating `cuda_device` / `use_cuda`
# as model hyperparameters (they only choose where to run) — confirm in labml docs.
hyperparams(DeviceConfigs.cuda_device, DeviceConfigs.use_cuda,
            is_hyperparam=False)
65
66
67
@option(DeviceConfigs.device_info)
def _device_info(c: DeviceConfigs):
    """
    Default option for `DeviceConfigs.device_info`: build a `DeviceInfo` from
    the config's `use_cuda` and `cuda_device` settings.
    """
    # Read the settings in the same order as before, via plain `c.<attr>` accesses
    want_cuda = c.use_cuda
    requested_index = c.cuda_device
    return DeviceInfo(cuda_device=requested_index, use_cuda=want_cuda)
71
72