AUTOMATIC1111
GitHub Repository: AUTOMATIC1111/stable-diffusion-webui
Path: blob/master/extensions-builtin/hypertile/hypertile.py
"""
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
Warning: the patch works well only if the input image has a width and height that are multiples of 128
Original author: @tfernd GitHub: https://github.com/tfernd/HyperTile
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable

from functools import wraps, cache

import math
import torch.nn as nn
import random

from einops import rearrange


@dataclass
class HypertileParams:
    depth: int = 0
    layer_name: str = ""
    tile_size: int = 0
    swap_size: int = 0
    aspect_ratio: float = 1.0
    forward: Callable | None = None
    enabled: bool = False

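# (Editor's note) One HypertileParams instance is attached to each patched
# attention module as module.__webui_hypertile_params in hypertile_hook_model,
# so the wrapper produced by self_attn_forward can read live settings per layer.
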

# TODO add SD-XL layers
DEPTH_LAYERS = {
    0: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.1.1.transformer_blocks.0.attn1",
        "input_blocks.2.1.transformer_blocks.0.attn1",
        "output_blocks.9.1.transformer_blocks.0.attn1",
        "output_blocks.10.1.transformer_blocks.0.attn1",
        "output_blocks.11.1.transformer_blocks.0.attn1",
        # SD 1.5 VAE
        "decoder.mid_block.attentions.0",
        "decoder.mid.attn_1",
    ],
    1: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.1.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.1.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.0.attn1",
        "input_blocks.5.1.transformer_blocks.0.attn1",
        "output_blocks.6.1.transformer_blocks.0.attn1",
        "output_blocks.7.1.transformer_blocks.0.attn1",
        "output_blocks.8.1.transformer_blocks.0.attn1",
    ],
    2: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.2.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.2.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.7.1.transformer_blocks.0.attn1",
        "input_blocks.8.1.transformer_blocks.0.attn1",
        "output_blocks.3.1.transformer_blocks.0.attn1",
        "output_blocks.4.1.transformer_blocks.0.attn1",
        "output_blocks.5.1.transformer_blocks.0.attn1",
    ],
    3: [
        # SD 1.5 U-Net (diffusers)
        "mid_block.attentions.0.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "middle_block.1.transformer_blocks.0.attn1",
    ],
}
# XL layers, thanks to GitHub @gel-crabs for the help
DEPTH_LAYERS_XL = {
    0: [
        # SDXL U-Net (diffusers)
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
        # SDXL U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.0.attn1",
        "input_blocks.5.1.transformer_blocks.0.attn1",
        "output_blocks.3.1.transformer_blocks.0.attn1",
        "output_blocks.4.1.transformer_blocks.0.attn1",
        "output_blocks.5.1.transformer_blocks.0.attn1",
        # SDXL VAE
        "decoder.mid_block.attentions.0",
        "decoder.mid.attn_1",
    ],
    1: [
        # SDXL U-Net (diffusers)
        #"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
        #"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
        # SDXL U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.1.attn1",
        "input_blocks.5.1.transformer_blocks.1.attn1",
        "output_blocks.3.1.transformer_blocks.1.attn1",
        "output_blocks.4.1.transformer_blocks.1.attn1",
        "output_blocks.5.1.transformer_blocks.1.attn1",
        "input_blocks.7.1.transformer_blocks.0.attn1",
        "input_blocks.8.1.transformer_blocks.0.attn1",
        "output_blocks.0.1.transformer_blocks.0.attn1",
        "output_blocks.1.1.transformer_blocks.0.attn1",
        "output_blocks.2.1.transformer_blocks.0.attn1",
        "input_blocks.7.1.transformer_blocks.1.attn1",
        "input_blocks.8.1.transformer_blocks.1.attn1",
        "output_blocks.0.1.transformer_blocks.1.attn1",
        "output_blocks.1.1.transformer_blocks.1.attn1",
        "output_blocks.2.1.transformer_blocks.1.attn1",
        "input_blocks.7.1.transformer_blocks.2.attn1",
        "input_blocks.8.1.transformer_blocks.2.attn1",
        "output_blocks.0.1.transformer_blocks.2.attn1",
        "output_blocks.1.1.transformer_blocks.2.attn1",
        "output_blocks.2.1.transformer_blocks.2.attn1",
        "input_blocks.7.1.transformer_blocks.3.attn1",
        "input_blocks.8.1.transformer_blocks.3.attn1",
        "output_blocks.0.1.transformer_blocks.3.attn1",
        "output_blocks.1.1.transformer_blocks.3.attn1",
        "output_blocks.2.1.transformer_blocks.3.attn1",
        "input_blocks.7.1.transformer_blocks.4.attn1",
        "input_blocks.8.1.transformer_blocks.4.attn1",
        "output_blocks.0.1.transformer_blocks.4.attn1",
        "output_blocks.1.1.transformer_blocks.4.attn1",
        "output_blocks.2.1.transformer_blocks.4.attn1",
        "input_blocks.7.1.transformer_blocks.5.attn1",
        "input_blocks.8.1.transformer_blocks.5.attn1",
        "output_blocks.0.1.transformer_blocks.5.attn1",
        "output_blocks.1.1.transformer_blocks.5.attn1",
        "output_blocks.2.1.transformer_blocks.5.attn1",
        "input_blocks.7.1.transformer_blocks.6.attn1",
        "input_blocks.8.1.transformer_blocks.6.attn1",
        "output_blocks.0.1.transformer_blocks.6.attn1",
        "output_blocks.1.1.transformer_blocks.6.attn1",
        "output_blocks.2.1.transformer_blocks.6.attn1",
        "input_blocks.7.1.transformer_blocks.7.attn1",
        "input_blocks.8.1.transformer_blocks.7.attn1",
        "output_blocks.0.1.transformer_blocks.7.attn1",
        "output_blocks.1.1.transformer_blocks.7.attn1",
        "output_blocks.2.1.transformer_blocks.7.attn1",
        "input_blocks.7.1.transformer_blocks.8.attn1",
        "input_blocks.8.1.transformer_blocks.8.attn1",
        "output_blocks.0.1.transformer_blocks.8.attn1",
        "output_blocks.1.1.transformer_blocks.8.attn1",
        "output_blocks.2.1.transformer_blocks.8.attn1",
        "input_blocks.7.1.transformer_blocks.9.attn1",
        "input_blocks.8.1.transformer_blocks.9.attn1",
        "output_blocks.0.1.transformer_blocks.9.attn1",
        "output_blocks.1.1.transformer_blocks.9.attn1",
        "output_blocks.2.1.transformer_blocks.9.attn1",
    ],
    2: [
        # SDXL U-Net (diffusers)
        "mid_block.attentions.0.transformer_blocks.0.attn1",
        # SDXL U-Net (ldm)
        "middle_block.1.transformer_blocks.0.attn1",
        "middle_block.1.transformer_blocks.1.attn1",
        "middle_block.1.transformer_blocks.2.attn1",
        "middle_block.1.transformer_blocks.3.attn1",
        "middle_block.1.transformer_blocks.4.attn1",
        "middle_block.1.transformer_blocks.5.attn1",
        "middle_block.1.transformer_blocks.6.attn1",
        "middle_block.1.transformer_blocks.7.attn1",
        "middle_block.1.transformer_blocks.8.attn1",
        "middle_block.1.transformer_blocks.9.attn1",
    ],
    3: [],  # TODO - separate layers for SD-XL
}
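
# (Editor's note) These names are matched by suffix: hypertile_hook_model below
# tests layer_name.endswith(try_name) against model.named_modules(), so both
# diffusers-style and ldm-style module paths can resolve against the same lists.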


RNG_INSTANCE = random.Random()


@cache
def get_divisors(value: int, min_value: int, /, max_options: int = 1) -> list[int]:
    """
    Returns divisors n of value such that n * min_value <= value,
    in descending order; the number of results is capped at max_options.
    """
    max_options = max(1, max_options)  # at least 1 option should be returned
    min_value = min(min_value, value)
    divisors = [i for i in range(min_value, value + 1) if value % i == 0]  # divisors in ascending order
    ns = [value // i for i in divisors[:max_options]]  # has at least 1 element; descending order
    return ns
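
# Worked example (editor's note, not part of the original module):
# get_divisors(32, 6, max_options=2) collects the divisors of 32 that are >= 6,
# i.e. [8, 16, 32], keeps the first two, and returns their cofactors
# [32 // 8, 32 // 16] == [4, 2]: candidate tile counts in descending order,
# each satisfying n * min_value <= value (4 * 6 = 24 <= 32, 2 * 6 = 12 <= 32).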


def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
    """
    Returns a random divisor n of value such that n * min_value <= value.
    If max_options is 1, the behavior is deterministic.
    """
    ns = get_divisors(value, min_value, max_options=max_options)  # get cached divisors
    idx = RNG_INSTANCE.randint(0, len(ns) - 1)

    return ns[idx]
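
# Example (editor's sketch): random_divisor(32, 6, 2) draws uniformly from
# get_divisors(32, 6, max_options=2) == [4, 2] using the module-level
# RNG_INSTANCE, so results are reproducible after set_hypertile_seed(seed);
# with max_options=1 the candidate list has one element and the result is fixed.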


def set_hypertile_seed(seed: int) -> None:
    RNG_INSTANCE.seed(seed)


@cache
def largest_tile_size_available(width: int, height: int) -> int:
    """
    Calculates the largest tile size available for a given width and height.
    The tile size is always a power of 2 (the largest one dividing both sides).
    """
    gcd = math.gcd(width, height)
    largest_tile_size_available = 1
    while gcd % (largest_tile_size_available * 2) == 0:
        largest_tile_size_available *= 2
    return largest_tile_size_available
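
# Worked example (editor's note): largest_tile_size_available(1024, 768) takes
# gcd(1024, 768) == 256 and returns 256; for a 1280x720 canvas the gcd is 80,
# whose largest power-of-2 divisor is 16, so the result is 16.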


def iterative_closest_divisors(hw: int, aspect_ratio: float) -> tuple[int, int]:
    """
    Finds h and w such that h * w == hw and w / h is as close as possible to aspect_ratio.
    Checks all divisor pairs of hw and returns the pair whose ratio is closest.
    """
    divisors = [i for i in range(2, hw + 1) if hw % i == 0]  # all divisors of hw
    pairs = [(i, hw // i) for i in divisors]  # all pairs of divisors of hw
    ratios = [w / h for h, w in pairs]  # ratios of all pairs of divisors of hw
    closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio))  # closest ratio to aspect_ratio
    closest_pair = pairs[ratios.index(closest_ratio)]  # the pair of divisors with that ratio
    return closest_pair
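
# Worked example (editor's note): iterative_closest_divisors(24, 1.5) enumerates
# the pairs (2, 12), (3, 8), (4, 6), (6, 4), (8, 3), (12, 2), (24, 1); the
# ratio w / h closest to 1.5 is 6 / 4 == 1.5, so (4, 6) is returned.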


@cache
def find_hw_candidates(hw: int, aspect_ratio: float) -> tuple[int, int]:
    """
    Finds h and w such that h * w == hw and their ratio is close to aspect_ratio.
    """
    h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
    # if rounding broke the factorization h * w == hw, try to repair one side exactly
    if h * w != hw:
        w_candidate = hw / h
        # check if w is an integer
        if not w_candidate.is_integer():
            h_candidate = hw / w
            # check if h is an integer
            if not h_candidate.is_integer():
                return iterative_closest_divisors(hw, aspect_ratio)
            else:
                h = int(h_candidate)
        else:
            w = int(w_candidate)
    return h, w
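
# Worked example (editor's note): find_hw_candidates(96, 1.5) computes
# h = round(sqrt(96 * 1.5)) == 12 and w = round(sqrt(96 / 1.5)) == 8; since
# 12 * 8 == 96 the fast path holds and (12, 8) is returned without falling
# back to iterative_closest_divisors.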


def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:

    @wraps(params.forward)
    def wrapper(*args, **kwargs):
        if not params.enabled:
            return params.forward(*args, **kwargs)

        latent_tile_size = max(128, params.tile_size) // 8
        x = args[0]

        # VAE
        if x.ndim == 4:
            b, c, h, w = x.shape

            nh = random_divisor(h, latent_tile_size, params.swap_size)
            nw = random_divisor(w, latent_tile_size, params.swap_size)

            if nh * nw > 1:
                x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw)  # split into nh * nw tiles

            out = params.forward(x, *args[1:], **kwargs)

            if nh * nw > 1:
                out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)

        # U-Net
        else:
            hw: int = x.size(1)
            h, w = find_hw_candidates(hw, params.aspect_ratio)
            assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"

            factor = 2 ** params.depth if scale_depth else 1
            nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
            nw = random_divisor(w, latent_tile_size * factor, params.swap_size)

            if nh * nw > 1:
                x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)

            out = params.forward(x, *args[1:], **kwargs)

            if nh * nw > 1:
                out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
                out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)

        return out

    return wrapper
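
# Shape walkthrough (editor's note, assuming a hypothetical 64x64 token grid):
# in the U-Net branch with h == w == 64 and nh == nw == 2, x goes from
# "b (64*64) c" to "(b*4) (32*32) c", the wrapped attention runs on the four
# smaller tiles, and the two final rearranges (exact inverses of the split)
# stitch the result back to "b (64*64) c" in the original token order.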


def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
    hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
    if hypertile_layers is None:
        if not enable:
            return

        hypertile_layers = {}
        layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS

        for depth in range(4):
            for layer_name, module in model.named_modules():
                if any(layer_name.endswith(try_name) for try_name in layers[depth]):
                    params = HypertileParams()
                    module.__webui_hypertile_params = params
                    params.forward = module.forward
                    params.depth = depth
                    params.layer_name = layer_name
                    module.forward = self_attn_forward(params)

                    hypertile_layers[layer_name] = 1

        model.__webui_hypertile_layers = hypertile_layers

    aspect_ratio = width / height
    tile_size = min(largest_tile_size_available(width, height), tile_size_max)

    for layer_name, module in model.named_modules():
        if layer_name in hypertile_layers:
            params = module.__webui_hypertile_params

            params.tile_size = tile_size
            params.swap_size = swap_size
            params.aspect_ratio = aspect_ratio
            params.enabled = enable and params.depth <= max_depth
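
# Usage sketch (editor's note; `unet` names a hypothetical SD-1.5 U-Net module):
#
#   set_hypertile_seed(seed)
#   hypertile_hook_model(unet, width, height, enable=True,
#                        tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False)
#
# Calling the hook again with enable=False keeps the patched forwards in place
# but turns tiling off by clearing params.enabled on every hooked layer.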