Path: blob/master/modules/lowvram.py
from collections import namedtuple

import torch
from modules import devices, shared

module_in_gpu = None
cpu = torch.device("cpu")

ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=[None])


def send_everything_to_cpu():
    global module_in_gpu

    if module_in_gpu is not None:
        module_in_gpu.to(cpu)

    module_in_gpu = None


def is_needed(sd_model):
    # offloading is wanted with --lowvram or --medvram, or with --medvram-sdxl
    # when the model is SDXL (only SDXL models have a conditioner)
    return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or (shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner'))


def apply(sd_model):
    enable = is_needed(sd_model)
    shared.parallel_processing_allowed = not enable

    if enable:
        setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
    else:
        sd_model.lowvram = False


def setup_for_low_vram(sd_model, use_medvram):
    if getattr(sd_model, 'lowvram', False):
        return

    sd_model.lowvram = True

    parents = {}

    def send_me_to_gpu(module, _):
        """send this module to GPU; send whatever tracked module was previously in GPU to CPU;
        we add this as forward_pre_hook to a lot of modules, and this way all but one of them
        will be in CPU
        """
        global module_in_gpu

        module = parents.get(module, module)

        if module_in_gpu == module:
            return

        if module_in_gpu is not None:
            module_in_gpu.to(cpu)

        module.to(devices.device)
        module_in_gpu = module

    # see below for register_forward_pre_hook;
    # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
    # useless here, and we just replace those methods

    first_stage_model = sd_model.first_stage_model
    first_stage_model_encode = sd_model.first_stage_model.encode
    first_stage_model_decode = sd_model.first_stage_model.decode

    def first_stage_model_encode_wrap(x):
        send_me_to_gpu(first_stage_model, None)
        return first_stage_model_encode(x)

    def first_stage_model_decode_wrap(z):
        send_me_to_gpu(first_stage_model, None)
        return first_stage_model_decode(z)

    to_remain_in_cpu = [
        (sd_model, 'first_stage_model'),
        (sd_model, 'depth_model'),
        (sd_model, 'embedder'),
        (sd_model, 'model'),
    ]

    is_sdxl = hasattr(sd_model, 'conditioner')
    is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')

    if hasattr(sd_model, 'medvram_fields'):
        to_remain_in_cpu = sd_model.medvram_fields()
    elif is_sdxl:
        to_remain_in_cpu.append((sd_model, 'conditioner'))
    elif is_sd2:
        to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
    else:
        to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))

    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
    stored = []
    for obj, field in to_remain_in_cpu:
        module = getattr(obj, field, None)
        stored.append(module)
        setattr(obj, field, None)

    # send the model to GPU; with the big modules detached, only the small remainder moves
    sd_model.to(devices.device)

    # put the modules back; they remain in CPU
    for (obj, field), module in zip(to_remain_in_cpu, stored):
        setattr(obj, field, module)

    # register hooks for the first three models
    if hasattr(sd_model, "cond_stage_model") and hasattr(sd_model.cond_stage_model, "medvram_modules"):
        for module in sd_model.cond_stage_model.medvram_modules():
            if isinstance(module, ModuleWithParent):
                parent = module.parent
                module = module.module
            else:
                parent = None

            if module:
                module.register_forward_pre_hook(send_me_to_gpu)

                if parent:
                    parents[module] = parent

    elif is_sdxl:
        sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
    elif is_sd2:
        sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
        sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
        parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
        parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
    else:
        sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.encode = first_stage_model_encode_wrap
    sd_model.first_stage_model.decode = first_stage_model_decode_wrap
    if getattr(sd_model, 'depth_model', None) is not None:
        sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
    if getattr(sd_model, 'embedder', None) is not None:
        sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)

    if use_medvram:
        sd_model.model.register_forward_pre_hook(send_me_to_gpu)
    else:
        diff_model = sd_model.model.diffusion_model

        # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
        # so that only one of them is in GPU at a time
        stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
        sd_model.model.to(devices.device)
        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored

        # install hooks for bits of the third model
        diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
        for block in diff_model.input_blocks:
            block.register_forward_pre_hook(send_me_to_gpu)
        diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
        for block in diff_model.output_blocks:
            block.register_forward_pre_hook(send_me_to_gpu)


def is_enabled(sd_model):
    return sd_model.lowvram
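

# A minimal, self-contained sketch of the forward_pre_hook juggling that
# send_me_to_gpu() performs above, isolated from webui: one hook shared by
# several submodules keeps at most one of them on the GPU at a time, evicting
# the previous occupant back to CPU. All `_demo`/`demo_` names below are
# hypothetical illustrations, not part of this module's API.
def _demo_peer_swapping_hooks():
    import torch.nn as nn

    demo_device = torch.device("cuda") if torch.cuda.is_available() else cpu
    state = {"in_gpu": None}  # tracks which module currently occupies the GPU

    def demo_send_me_to_gpu(module, _inputs):
        # runs just before module.forward(); evict the previous occupant
        if state["in_gpu"] is module:
            return
        if state["in_gpu"] is not None:
            state["in_gpu"].to(cpu)
        module.to(demo_device)
        state["in_gpu"] = module

    blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
    for block in blocks:
        block.register_forward_pre_hook(demo_send_me_to_gpu)

    x = torch.randn(1, 8)
    for block in blocks:
        # the hook moves `block` onto demo_device; every other block is on CPU
        x = block(x.to(demo_device))
    return x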
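

# Another hypothetical sketch, of the detach/move/reattach trick that
# setup_for_low_vram() uses twice above: temporarily setting the big
# submodules to None so that .to(device) only moves the small remainder of
# the model, then restoring them, still on CPU. nn.Module.to() skips
# submodule slots that hold None, which is what makes this work.
def _demo_detach_move_reattach():
    import torch.nn as nn

    demo_device = torch.device("cuda") if torch.cuda.is_available() else cpu

    model = nn.Module()
    model.small = nn.Linear(4, 4)        # cheap; fine to keep on the GPU
    model.big = nn.Linear(1024, 1024)    # expensive; should stay on CPU

    big = model.big
    model.big = None          # detach, so .to() below skips it
    model.to(demo_device)     # moves only model.small
    model.big = big           # reattach; its parameters are still on CPU

    assert big.weight.device == cpu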