GitHub Repository: automatic1111/stable-diffusion-webui
Path: blob/master/modules/devices.py
import sys
import contextlib
from functools import lru_cache

import torch
from modules import errors, shared, npu_specific

if sys.platform == "darwin":
    from modules import mac_specific

if shared.cmd_opts.use_ipex:
    from modules import xpu_specific


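# Backend availability helpers: XPU (Intel, via --use-ipex) and MPS (Apple Silicon) are only
# probed when the corresponding platform-specific module was imported above.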
def has_xpu() -> bool:
    return shared.cmd_opts.use_ipex and xpu_specific.has_xpu


def has_mps() -> bool:
    if sys.platform != "darwin":
        return False
    else:
        return mac_specific.has_mps


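# The webui treats GTX 16xx-series cards (compute capability 7.5, device name starting with
# "NVIDIA GeForce GTX 16") as unable to use torch.autocast and routes them through the
# manual-cast path instead; see autocast() below.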
def cuda_no_autocast(device_id=None) -> bool:
    if device_id is None:
        device_id = get_cuda_device_id()
    return (
        torch.cuda.get_device_capability(device_id) == (7, 5)
        and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
    )


def get_cuda_device_id():
    return (
        int(shared.cmd_opts.device_id)
        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
        else 0
    ) or torch.cuda.current_device()


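# Device selection: prefer CUDA, then MPS, then Intel XPU, then Ascend NPU, falling back to CPU.
# get_device_for() lets individual tasks be pinned to the CPU via the --use-cpu option.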
def get_cuda_device_string():
    if shared.cmd_opts.device_id is not None:
        return f"cuda:{shared.cmd_opts.device_id}"

    return "cuda"


def get_optimal_device_name():
    if torch.cuda.is_available():
        return get_cuda_device_string()

    if has_mps():
        return "mps"

    if has_xpu():
        return xpu_specific.get_xpu_device_string()

    if npu_specific.has_npu:
        return npu_specific.get_npu_device_string()

    return "cpu"


def get_optimal_device():
    return torch.device(get_optimal_device_name())


def get_device_for(task):
    if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
        return cpu

    return get_optimal_device()


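# Free cached memory on whichever accelerator backend is active. Callers typically invoke this
# between generations to release memory held by PyTorch's caching allocator.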
def torch_gc():

    if torch.cuda.is_available():
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    if has_mps():
        mac_specific.torch_mps_gc()

    if has_xpu():
        xpu_specific.torch_xpu_gc()

    if npu_specific.has_npu:
        torch_npu_set_device()
        npu_specific.torch_npu_gc()


def torch_npu_set_device():
    # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
    if npu_specific.has_npu:
        torch.npu.set_device(0)


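# Allow TF32 matmul/cuDNN kernels (a speedup on Ampere and newer GPUs at slightly reduced
# precision). Wrapped in errors.run() below so a failure here does not abort startup.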
def enable_tf32():
    if torch.cuda.is_available():

        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
        if cuda_no_autocast():
            torch.backends.cudnn.benchmark = True

        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


errors.run(enable_tf32, "Enabling TF32")

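# Module-level device/dtype state. The values below are defaults; they are reassigned during
# webui startup according to the command-line options in use.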
cpu: torch.device = torch.device("cpu")
fp8: bool = False
# Force fp16 for all models in inference. No casting during inference.
# This flag is controlled by "--precision half" command line arg.
force_fp16: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
device_esrgan: torch.device = None
device_codeformer: torch.device = None
dtype: torch.dtype = torch.float16
dtype_vae: torch.dtype = torch.float16
dtype_unet: torch.dtype = torch.float16
dtype_inference: torch.dtype = torch.float16
unet_needs_upcast = False


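# Conditional casts applied around the UNet: inputs are converted to the UNet dtype (or back to
# float) only when unet_needs_upcast is set, except that force_fp16 always casts to float16.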
def cond_cast_unet(input):
    if force_fp16:
        return input.to(torch.float16)
    return input.to(dtype_unet) if unet_needs_upcast else input


def cond_cast_float(input):
    return input.float() if unet_needs_upcast else input


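# Module types whose forward() gets temporarily patched by manual_cast() below when
# torch.autocast cannot be used for the current device/dtype combination.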
nv_rng = None
patch_module_list = [
    torch.nn.Linear,
    torch.nn.Conv2d,
    torch.nn.MultiheadAttention,
    torch.nn.GroupNorm,
    torch.nn.LayerNorm,
]


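# Returns a forward() replacement that casts tensor args and module parameters to target_dtype,
# runs the original forward, then casts the result to dtype_inference if the two dtypes differ.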
def manual_cast_forward(target_dtype):
    def forward_wrapper(self, *args, **kwargs):
        if any(
            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
            for arg in args
        ):
            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}

        org_dtype = target_dtype
        for param in self.parameters():
            if param.dtype != target_dtype:
                org_dtype = param.dtype
                break

        if org_dtype != target_dtype:
            self.to(target_dtype)
        result = self.org_forward(*args, **kwargs)
        if org_dtype != target_dtype:
            self.to(org_dtype)

        if target_dtype != dtype_inference:
            if isinstance(result, tuple):
                result = tuple(
                    i.to(dtype_inference)
                    if isinstance(i, torch.Tensor)
                    else i
                    for i in result
                )
            elif isinstance(result, torch.Tensor):
                result = result.to(dtype_inference)
        return result
    return forward_wrapper


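# Context manager that patches forward() on the module types above for the duration of the block
# and restores the originals on exit; MultiheadAttention is pinned to float32.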
@contextlib.contextmanager
def manual_cast(target_dtype):
    applied = False
    for module_type in patch_module_list:
        if hasattr(module_type, "org_forward"):
            continue
        applied = True
        org_forward = module_type.forward
        if module_type == torch.nn.MultiheadAttention:
            module_type.forward = manual_cast_forward(torch.float32)
        else:
            module_type.forward = manual_cast_forward(target_dtype)
        module_type.org_forward = org_forward
    try:
        yield None
    finally:
        if applied:
            for module_type in patch_module_list:
                if hasattr(module_type, "org_forward"):
                    module_type.forward = module_type.org_forward
                    delattr(module_type, "org_forward")


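# Returns the context manager inference code should run under: a no-op when casting is disabled
# or unnecessary, manual_cast() when torch.autocast can't be used (XPU, MPS, GTX 16xx, or fp8
# with float32 inference), and torch.autocast("cuda") otherwise.
#
# Illustrative call-site sketch (not part of this module):
#
#     with devices.autocast():
#         output = model(x.to(devices.device))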
def autocast(disable=False):
    if disable:
        return contextlib.nullcontext()

    if force_fp16:
        # No casting during inference if force_fp16 is enabled.
        # All tensor dtype conversion happens before inference.
        return contextlib.nullcontext()

    if fp8 and device == cpu:
        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)

    if fp8 and dtype_inference == torch.float32:
        return manual_cast(dtype)

    if dtype == torch.float32 or dtype_inference == torch.float32:
        return contextlib.nullcontext()

    if has_xpu() or has_mps() or cuda_no_autocast():
        return manual_cast(dtype)

    return torch.autocast("cuda")


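# without_autocast() temporarily disables an active CUDA autocast region, e.g. for ops that must
# run at full precision; it is a no-op if autocast isn't currently enabled.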
def without_autocast(disable=False):
    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()


class NansException(Exception):
    pass


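# Cheap NaN check: only element [0, 0, ...] of the tensor is inspected. On failure this raises
# NansException with a message suggesting the relevant --no-half/--no-half-vae workarounds.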
def test_for_nans(x, where):
    if shared.cmd_opts.disable_nan_check:
        return

    if not torch.isnan(x[(0, ) * len(x.shape)]):
        return

    if where == "unet":
        message = "A tensor with NaNs was produced in Unet."

        if not shared.cmd_opts.no_half:
            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."

    elif where == "vae":
        message = "A tensor with NaNs was produced in VAE."

        if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
            message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
    else:
        message = "A tensor with NaNs was produced."

    message += " Use --disable-nan-check commandline argument to disable this check."

    raise NansException(message)


@lru_cache
def first_time_calculation():
    """
    Just do any calculation with PyTorch layers - the first time this is done it allocates
    about 700MB of memory and spends about 2.7 seconds doing that, at least on NVIDIA.
    """

    x = torch.zeros((1, 1)).to(device, dtype)
    linear = torch.nn.Linear(1, 1).to(device, dtype)
    linear(x)

    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv2d(x)


def force_model_fp16():
    """
    ldm and sgm have modules.diffusionmodules.util.GroupNorm32.forward, which
    forces conversion of the input to float32. If force_fp16 is enabled, we need to
    prevent this casting.
    """
    assert force_fp16
    import sgm.modules.diffusionmodules.util as sgm_util
    import ldm.modules.diffusionmodules.util as ldm_util
    sgm_util.GroupNorm32 = torch.nn.GroupNorm
    ldm_util.GroupNorm32 = torch.nn.GroupNorm
    print("ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.")