Path: blob/master/modules/devices.py
import sys
import contextlib
from functools import lru_cache

import torch
from modules import errors, shared, npu_specific

if sys.platform == "darwin":
    from modules import mac_specific

if shared.cmd_opts.use_ipex:
    from modules import xpu_specific


def has_xpu() -> bool:
    return shared.cmd_opts.use_ipex and xpu_specific.has_xpu


def has_mps() -> bool:
    if sys.platform != "darwin":
        return False
    else:
        return mac_specific.has_mps


def cuda_no_autocast(device_id=None) -> bool:
    if device_id is None:
        device_id = get_cuda_device_id()
    return (
        torch.cuda.get_device_capability(device_id) == (7, 5)
        and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
    )


def get_cuda_device_id():
    return (
        int(shared.cmd_opts.device_id)
        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
        else 0
    ) or torch.cuda.current_device()


def get_cuda_device_string():
    if shared.cmd_opts.device_id is not None:
        return f"cuda:{shared.cmd_opts.device_id}"

    return "cuda"


def get_optimal_device_name():
    if torch.cuda.is_available():
        return get_cuda_device_string()

    if has_mps():
        return "mps"

    if has_xpu():
        return xpu_specific.get_xpu_device_string()

    if npu_specific.has_npu:
        return npu_specific.get_npu_device_string()

    return "cpu"


def get_optimal_device():
    return torch.device(get_optimal_device_name())


def get_device_for(task):
    if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
        return cpu

    return get_optimal_device()


def torch_gc():

    if torch.cuda.is_available():
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    if has_mps():
        mac_specific.torch_mps_gc()

    if has_xpu():
        xpu_specific.torch_xpu_gc()

    if npu_specific.has_npu:
        torch_npu_set_device()
        npu_specific.torch_npu_gc()


def torch_npu_set_device():
    # Workaround for a bug in torch_npu; revert after it is fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
    if npu_specific.has_npu:
        torch.npu.set_device(0)


def enable_tf32():
    if torch.cuda.is_available():

        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
        if cuda_no_autocast():
            torch.backends.cudnn.benchmark = True

        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


errors.run(enable_tf32, "Enabling TF32")

cpu: torch.device = torch.device("cpu")
fp8: bool = False
# Force fp16 for all models in inference. No casting during inference.
# This flag is controlled by "--precision half" command line arg.
force_fp16: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
device_esrgan: torch.device = None
device_codeformer: torch.device = None
dtype: torch.dtype = torch.float16
dtype_vae: torch.dtype = torch.float16
dtype_unet: torch.dtype = torch.float16
dtype_inference: torch.dtype = torch.float16
unet_needs_upcast = False


def cond_cast_unet(input):
    if force_fp16:
        return input.to(torch.float16)
    return input.to(dtype_unet) if unet_needs_upcast else input


def cond_cast_float(input):
    return input.float() if unet_needs_upcast else input


nv_rng = None
patch_module_list = [
    torch.nn.Linear,
    torch.nn.Conv2d,
    torch.nn.MultiheadAttention,
    torch.nn.GroupNorm,
    torch.nn.LayerNorm,
]


def manual_cast_forward(target_dtype):
    def forward_wrapper(self, *args, **kwargs):
        if any(
            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
            for arg in args
        ):
            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}

        org_dtype = target_dtype
        for param in self.parameters():
            if param.dtype != target_dtype:
                org_dtype = param.dtype
                break

        if org_dtype != target_dtype:
            self.to(target_dtype)
        result = self.org_forward(*args, **kwargs)
        if org_dtype != target_dtype:
            self.to(org_dtype)

        if target_dtype != dtype_inference:
            if isinstance(result, tuple):
                result = tuple(
                    i.to(dtype_inference)
                    if isinstance(i, torch.Tensor)
                    else i
                    for i in result
                )
            elif isinstance(result, torch.Tensor):
                result = result.to(dtype_inference)
        return result
    return forward_wrapper


@contextlib.contextmanager
def manual_cast(target_dtype):
    applied = False
    for module_type in patch_module_list:
        if hasattr(module_type, "org_forward"):
            continue
        applied = True
        org_forward = module_type.forward
        if module_type == torch.nn.MultiheadAttention:
            module_type.forward = manual_cast_forward(torch.float32)
        else:
            module_type.forward = manual_cast_forward(target_dtype)
        module_type.org_forward = org_forward
    try:
        yield None
    finally:
        if applied:
            for module_type in patch_module_list:
                if hasattr(module_type, "org_forward"):
                    module_type.forward = module_type.org_forward
                    delattr(module_type, "org_forward")


def autocast(disable=False):
    if disable:
        return contextlib.nullcontext()

    if force_fp16:
        # No casting during inference if force_fp16 is enabled.
        # All tensor dtype conversion happens before inference.
        return contextlib.nullcontext()

    if fp8 and device == cpu:
        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)

    if fp8 and dtype_inference == torch.float32:
        return manual_cast(dtype)

    if dtype == torch.float32 or dtype_inference == torch.float32:
        return contextlib.nullcontext()

    if has_xpu() or has_mps() or cuda_no_autocast():
        return manual_cast(dtype)

    return torch.autocast("cuda")


def without_autocast(disable=False):
    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()


class NansException(Exception):
    pass


def test_for_nans(x, where):
    if shared.cmd_opts.disable_nan_check:
        return

    if not torch.isnan(x[(0, ) * len(x.shape)]):
        return

    if where == "unet":
        message = "A tensor with NaNs was produced in Unet."

        if not shared.cmd_opts.no_half:
            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."

    elif where == "vae":
        message = "A tensor with NaNs was produced in VAE."

        if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
            message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
    else:
        message = "A tensor with NaNs was produced."

    message += " Use --disable-nan-check commandline argument to disable this check."

    raise NansException(message)


@lru_cache
def first_time_calculation():
    """
    Just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
    spends about 2.7 seconds doing that, at least with NVidia.
    """

    x = torch.zeros((1, 1)).to(device, dtype)
    linear = torch.nn.Linear(1, 1).to(device, dtype)
    linear(x)

    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv2d(x)


def force_model_fp16():
    """
    ldm and sgm have modules.diffusionmodules.util.GroupNorm32.forward, which
    forces conversion of the input to float32. If force_fp16 is enabled, we need to
    prevent this casting.
    """
    assert force_fp16
    import sgm.modules.diffusionmodules.util as sgm_util
    import ldm.modules.diffusionmodules.util as ldm_util
    sgm_util.GroupNorm32 = torch.nn.GroupNorm
    ldm_util.GroupNorm32 = torch.nn.GroupNorm
    print("ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.")
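For context, here is a minimal sketch of how these helpers are typically consumed elsewhere in the webui; the caller function, model, and tensor names below are illustrative assumptions and are not part of devices.py.

from modules import devices

def run_denoising_step(model, latent):
    # Hypothetical call site. autocast() returns a nullcontext, a torch.autocast
    # context, or manual_cast(), depending on the dtype settings and backend
    # quirks handled in devices.py above.
    with devices.autocast():
        out = model(latent)
    # Raises NansException if the output contains NaNs (skipped with --disable-nan-check).
    devices.test_for_nans(out, "unet")
    # Free cached memory on whichever backend is in use.
    devices.torch_gc()
    return out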