# Path: blob/master/extensions-builtin/hypertile/hypertile.py
# 2305 views
"""1Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE2Warn: The patch works well only if the input image has a width and height that are multiples of 1283Original author: @tfernd Github: https://github.com/tfernd/HyperTile4"""56from __future__ import annotations78from dataclasses import dataclass9from typing import Callable1011from functools import wraps, cache1213import math14import torch.nn as nn15import random1617from einops import rearrange181920@dataclass21class HypertileParams:22depth = 023layer_name = ""24tile_size: int = 025swap_size: int = 026aspect_ratio: float = 1.027forward = None28enabled = False29303132# TODO add SD-XL layers33DEPTH_LAYERS = {340: [35# SD 1.5 U-Net (diffusers)36"down_blocks.0.attentions.0.transformer_blocks.0.attn1",37"down_blocks.0.attentions.1.transformer_blocks.0.attn1",38"up_blocks.3.attentions.0.transformer_blocks.0.attn1",39"up_blocks.3.attentions.1.transformer_blocks.0.attn1",40"up_blocks.3.attentions.2.transformer_blocks.0.attn1",41# SD 1.5 U-Net (ldm)42"input_blocks.1.1.transformer_blocks.0.attn1",43"input_blocks.2.1.transformer_blocks.0.attn1",44"output_blocks.9.1.transformer_blocks.0.attn1",45"output_blocks.10.1.transformer_blocks.0.attn1",46"output_blocks.11.1.transformer_blocks.0.attn1",47# SD 1.5 VAE48"decoder.mid_block.attentions.0",49"decoder.mid.attn_1",50],511: [52# SD 1.5 U-Net (diffusers)53"down_blocks.1.attentions.0.transformer_blocks.0.attn1",54"down_blocks.1.attentions.1.transformer_blocks.0.attn1",55"up_blocks.2.attentions.0.transformer_blocks.0.attn1",56"up_blocks.2.attentions.1.transformer_blocks.0.attn1",57"up_blocks.2.attentions.2.transformer_blocks.0.attn1",58# SD 1.5 U-Net (ldm)59"input_blocks.4.1.transformer_blocks.0.attn1",60"input_blocks.5.1.transformer_blocks.0.attn1",61"output_blocks.6.1.transformer_blocks.0.attn1",62"output_blocks.7.1.transformer_blocks.0.attn1",63"output_blocks.8.1.transformer_blocks.0.attn1",64],652: [66# SD 1.5 U-Net 
(diffusers)67"down_blocks.2.attentions.0.transformer_blocks.0.attn1",68"down_blocks.2.attentions.1.transformer_blocks.0.attn1",69"up_blocks.1.attentions.0.transformer_blocks.0.attn1",70"up_blocks.1.attentions.1.transformer_blocks.0.attn1",71"up_blocks.1.attentions.2.transformer_blocks.0.attn1",72# SD 1.5 U-Net (ldm)73"input_blocks.7.1.transformer_blocks.0.attn1",74"input_blocks.8.1.transformer_blocks.0.attn1",75"output_blocks.3.1.transformer_blocks.0.attn1",76"output_blocks.4.1.transformer_blocks.0.attn1",77"output_blocks.5.1.transformer_blocks.0.attn1",78],793: [80# SD 1.5 U-Net (diffusers)81"mid_block.attentions.0.transformer_blocks.0.attn1",82# SD 1.5 U-Net (ldm)83"middle_block.1.transformer_blocks.0.attn1",84],85}86# XL layers, thanks for GitHub@gel-crabs for the help87DEPTH_LAYERS_XL = {880: [89# SD 1.5 U-Net (diffusers)90"down_blocks.0.attentions.0.transformer_blocks.0.attn1",91"down_blocks.0.attentions.1.transformer_blocks.0.attn1",92"up_blocks.3.attentions.0.transformer_blocks.0.attn1",93"up_blocks.3.attentions.1.transformer_blocks.0.attn1",94"up_blocks.3.attentions.2.transformer_blocks.0.attn1",95# SD 1.5 U-Net (ldm)96"input_blocks.4.1.transformer_blocks.0.attn1",97"input_blocks.5.1.transformer_blocks.0.attn1",98"output_blocks.3.1.transformer_blocks.0.attn1",99"output_blocks.4.1.transformer_blocks.0.attn1",100"output_blocks.5.1.transformer_blocks.0.attn1",101# SD 1.5 VAE102"decoder.mid_block.attentions.0",103"decoder.mid.attn_1",104],1051: [106# SD 1.5 U-Net (diffusers)107#"down_blocks.1.attentions.0.transformer_blocks.0.attn1",108#"down_blocks.1.attentions.1.transformer_blocks.0.attn1",109#"up_blocks.2.attentions.0.transformer_blocks.0.attn1",110#"up_blocks.2.attentions.1.transformer_blocks.0.attn1",111#"up_blocks.2.attentions.2.transformer_blocks.0.attn1",112# SD 1.5 U-Net 
(ldm)113"input_blocks.4.1.transformer_blocks.1.attn1",114"input_blocks.5.1.transformer_blocks.1.attn1",115"output_blocks.3.1.transformer_blocks.1.attn1",116"output_blocks.4.1.transformer_blocks.1.attn1",117"output_blocks.5.1.transformer_blocks.1.attn1",118"input_blocks.7.1.transformer_blocks.0.attn1",119"input_blocks.8.1.transformer_blocks.0.attn1",120"output_blocks.0.1.transformer_blocks.0.attn1",121"output_blocks.1.1.transformer_blocks.0.attn1",122"output_blocks.2.1.transformer_blocks.0.attn1",123"input_blocks.7.1.transformer_blocks.1.attn1",124"input_blocks.8.1.transformer_blocks.1.attn1",125"output_blocks.0.1.transformer_blocks.1.attn1",126"output_blocks.1.1.transformer_blocks.1.attn1",127"output_blocks.2.1.transformer_blocks.1.attn1",128"input_blocks.7.1.transformer_blocks.2.attn1",129"input_blocks.8.1.transformer_blocks.2.attn1",130"output_blocks.0.1.transformer_blocks.2.attn1",131"output_blocks.1.1.transformer_blocks.2.attn1",132"output_blocks.2.1.transformer_blocks.2.attn1",133"input_blocks.7.1.transformer_blocks.3.attn1",134"input_blocks.8.1.transformer_blocks.3.attn1",135"output_blocks.0.1.transformer_blocks.3.attn1",136"output_blocks.1.1.transformer_blocks.3.attn1",137"output_blocks.2.1.transformer_blocks.3.attn1",138"input_blocks.7.1.transformer_blocks.4.attn1",139"input_blocks.8.1.transformer_blocks.4.attn1",140"output_blocks.0.1.transformer_blocks.4.attn1",141"output_blocks.1.1.transformer_blocks.4.attn1",142"output_blocks.2.1.transformer_blocks.4.attn1",143"input_blocks.7.1.transformer_blocks.5.attn1",144"input_blocks.8.1.transformer_blocks.5.attn1",145"output_blocks.0.1.transformer_blocks.5.attn1",146"output_blocks.1.1.transformer_blocks.5.attn1",147"output_blocks.2.1.transformer_blocks.5.attn1",148"input_blocks.7.1.transformer_blocks.6.attn1",149"input_blocks.8.1.transformer_blocks.6.attn1",150"output_blocks.0.1.transformer_blocks.6.attn1",151"output_blocks.1.1.transformer_blocks.6.attn1",152"output_blocks.2.1.transformer_blocks.6.attn1",153"input_b
locks.7.1.transformer_blocks.7.attn1",154"input_blocks.8.1.transformer_blocks.7.attn1",155"output_blocks.0.1.transformer_blocks.7.attn1",156"output_blocks.1.1.transformer_blocks.7.attn1",157"output_blocks.2.1.transformer_blocks.7.attn1",158"input_blocks.7.1.transformer_blocks.8.attn1",159"input_blocks.8.1.transformer_blocks.8.attn1",160"output_blocks.0.1.transformer_blocks.8.attn1",161"output_blocks.1.1.transformer_blocks.8.attn1",162"output_blocks.2.1.transformer_blocks.8.attn1",163"input_blocks.7.1.transformer_blocks.9.attn1",164"input_blocks.8.1.transformer_blocks.9.attn1",165"output_blocks.0.1.transformer_blocks.9.attn1",166"output_blocks.1.1.transformer_blocks.9.attn1",167"output_blocks.2.1.transformer_blocks.9.attn1",168],1692: [170# SD 1.5 U-Net (diffusers)171"mid_block.attentions.0.transformer_blocks.0.attn1",172# SD 1.5 U-Net (ldm)173"middle_block.1.transformer_blocks.0.attn1",174"middle_block.1.transformer_blocks.1.attn1",175"middle_block.1.transformer_blocks.2.attn1",176"middle_block.1.transformer_blocks.3.attn1",177"middle_block.1.transformer_blocks.4.attn1",178"middle_block.1.transformer_blocks.5.attn1",179"middle_block.1.transformer_blocks.6.attn1",180"middle_block.1.transformer_blocks.7.attn1",181"middle_block.1.transformer_blocks.8.attn1",182"middle_block.1.transformer_blocks.9.attn1",183],1843 : [] # TODO - separate layers for SD-XL185}186187188RNG_INSTANCE = random.Random()189190@cache191def get_divisors(value: int, min_value: int, /, max_options: int = 1) -> list[int]:192"""193Returns divisors of value that194x * min_value <= value195in big -> small order, amount of divisors is limited by max_options196"""197max_options = max(1, max_options) # at least 1 option should be returned198min_value = min(min_value, value)199divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order200ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order201return ns202203204def 
random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:205"""206Returns a random divisor of value that207x * min_value <= value208if max_options is 1, the behavior is deterministic209"""210ns = get_divisors(value, min_value, max_options=max_options) # get cached divisors211idx = RNG_INSTANCE.randint(0, len(ns) - 1)212213return ns[idx]214215216def set_hypertile_seed(seed: int) -> None:217RNG_INSTANCE.seed(seed)218219220@cache221def largest_tile_size_available(width: int, height: int) -> int:222"""223Calculates the largest tile size available for a given width and height224Tile size is always a power of 2225"""226gcd = math.gcd(width, height)227largest_tile_size_available = 1228while gcd % (largest_tile_size_available * 2) == 0:229largest_tile_size_available *= 2230return largest_tile_size_available231232233def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:234"""235Finds h and w such that h*w = hw and h/w = aspect_ratio236We check all possible divisors of hw and return the closest to the aspect ratio237"""238divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw239pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw240ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw241closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio242closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio243return closest_pair244245246@cache247def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:248"""249Finds h and w such that h*w = hw and h/w = aspect_ratio250"""251h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))252# find h and w such that h*w = hw and h/w = aspect_ratio253if h * w != hw:254w_candidate = hw / h255# check if w is an integer256if not w_candidate.is_integer():257h_candidate = hw / w258# check if h is an integer259if not 
h_candidate.is_integer():260return iterative_closest_divisors(hw, aspect_ratio)261else:262h = int(h_candidate)263else:264w = int(w_candidate)265return h, w266267268def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:269270@wraps(params.forward)271def wrapper(*args, **kwargs):272if not params.enabled:273return params.forward(*args, **kwargs)274275latent_tile_size = max(128, params.tile_size) // 8276x = args[0]277278# VAE279if x.ndim == 4:280b, c, h, w = x.shape281282nh = random_divisor(h, latent_tile_size, params.swap_size)283nw = random_divisor(w, latent_tile_size, params.swap_size)284285if nh * nw > 1:286x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles287288out = params.forward(x, *args[1:], **kwargs)289290if nh * nw > 1:291out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)292293# U-Net294else:295hw: int = x.size(1)296h, w = find_hw_candidates(hw, params.aspect_ratio)297assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"298299factor = 2 ** params.depth if scale_depth else 1300nh = random_divisor(h, latent_tile_size * factor, params.swap_size)301nw = random_divisor(w, latent_tile_size * factor, params.swap_size)302303if nh * nw > 1:304x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)305306out = params.forward(x, *args[1:], **kwargs)307308if nh * nw > 1:309out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)310out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)311312return out313314return wrapper315316317def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):318hypertile_layers = getattr(model, "__webui_hypertile_layers", None)319if hypertile_layers is None:320if not enable:321return322323hypertile_layers = {}324layers = DEPTH_LAYERS_XL if 
is_sdxl else DEPTH_LAYERS325326for depth in range(4):327for layer_name, module in model.named_modules():328if any(layer_name.endswith(try_name) for try_name in layers[depth]):329params = HypertileParams()330module.__webui_hypertile_params = params331params.forward = module.forward332params.depth = depth333params.layer_name = layer_name334module.forward = self_attn_forward(params)335336hypertile_layers[layer_name] = 1337338model.__webui_hypertile_layers = hypertile_layers339340aspect_ratio = width / height341tile_size = min(largest_tile_size_available(width, height), tile_size_max)342343for layer_name, module in model.named_modules():344if layer_name in hypertile_layers:345params = module.__webui_hypertile_params346347params.tile_size = tile_size348params.swap_size = swap_size349params.aspect_ratio = aspect_ratio350params.enabled = enable and params.depth <= max_depth351352353