GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/community/tiled_upscaling.py
# Copyright 2023 Peter Willemsen <[email protected]>. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Callable, List, Optional, Union

import numpy as np
import PIL
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from diffusers.schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler


def make_transparency_mask(size, overlap_pixels, remove_borders=[]):
    # Build an 8-bit alpha mask that is fully opaque (255) in the tile
    # interior and ramps linearly down to 0 over `overlap_pixels` at each
    # border, so neighbouring tiles blend smoothly when pasted together.
    size_x = size[0] - overlap_pixels * 2
    size_y = size[1] - overlap_pixels * 2
    for letter in ["l", "r"]:
        if letter in remove_borders:
            size_x += overlap_pixels
    for letter in ["t", "b"]:
        if letter in remove_borders:
            size_y += overlap_pixels
    mask = np.ones((size_y, size_x), dtype=np.uint8) * 255
    mask = np.pad(mask, mode="linear_ramp", pad_width=overlap_pixels, end_values=0)

    # Borders listed in `remove_borders` ("l", "r", "t", "b") lie on the edge
    # of the full image; their ramp is cropped off so that edge stays opaque.
    if "l" in remove_borders:
        mask = mask[:, overlap_pixels : mask.shape[1]]
    if "r" in remove_borders:
        mask = mask[:, 0 : mask.shape[1] - overlap_pixels]
    if "t" in remove_borders:
        mask = mask[overlap_pixels : mask.shape[0], :]
    if "b" in remove_borders:
        mask = mask[0 : mask.shape[0] - overlap_pixels, :]
    return mask
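
# A worked example for make_transparency_mask (illustrative, approximate
# values, not from the source): a (6, 6) mask with overlap_pixels=2 and no
# removed borders has interior rows that ramp up and back down, roughly
# [0, 128, 255, 255, 128, 0]; an edge listed in remove_borders keeps full
# opacity (255) at that edge instead of ramping.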


def clamp(n, smallest, largest):
    # Clamp `n` to the inclusive range [smallest, largest].
    return max(smallest, min(n, largest))


def clamp_rect(rect: List[int], min: List[int], max: List[int]):
    # Clamp an (x0, y0, x1, y1) rectangle to the bounds given by the `min`
    # and `max` (x, y) corner points.
    return (
        clamp(rect[0], min[0], max[0]),
        clamp(rect[1], min[1], max[1]),
        clamp(rect[2], min[0], max[0]),
        clamp(rect[3], min[1], max[1]),
    )


def add_overlap_rect(rect: List[int], overlap: int, image_size: List[int]):
    # Grow a tile rectangle by `overlap` pixels on every side, clamped to
    # the image bounds, so adjacent tiles share a blending region.
    rect = list(rect)
    rect[0] -= overlap
    rect[1] -= overlap
    rect[2] += overlap
    rect[3] += overlap
    rect = clamp_rect(rect, [0, 0], [image_size[0], image_size[1]])
    return rect
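
# Illustrative example (values are mine, not from the source): a 128px tile
# at grid position (1, 1) of a 512x512 image grows by a 32px border on every
# side:
#   add_overlap_rect((128, 128, 256, 256), 32, (512, 512)) == (96, 96, 288, 288)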


def squeeze_tile(tile, original_image, original_slice, slice_x):
    # Prepend a strip of a tile-sized thumbnail of the original image to the
    # left of the tile. This gives the upscaler global context and reduces
    # visible seams; the strip is cropped away again in `unsqueeze_tile`.
    result = Image.new("RGB", (tile.size[0] + original_slice, tile.size[1]))
    result.paste(
        original_image.resize((tile.size[0], tile.size[1]), Image.BICUBIC).crop(
            (slice_x, 0, slice_x + original_slice, tile.size[1])
        ),
        (0, 0),
    )
    result.paste(tile, (original_slice, 0))
    return result


def unsqueeze_tile(tile, original_image_slice):
    # Crop away the context strip added by `squeeze_tile`; the strip is 4x
    # wider here because the tile has been upscaled by a factor of 4.
    crop_rect = (original_image_slice * 4, 0, tile.size[0], tile.size[1])
    tile = tile.crop(crop_rect)
    return tile


def next_divisible(n, d):
    # Round `n` down to the nearest multiple of `d`.
    remainder = n % d
    return n - remainder
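
# For example: next_divisible(300, 128) == 256, the largest multiple of 128
# that does not exceed 300.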


class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
    r"""
    Pipeline for tile-based text-guided image super-resolution using Stable Diffusion 2, trading memory for compute
    to create gigantic images.

    This model inherits from [`StableDiffusionUpscalePipeline`]. Check the superclass documentation for the generic
    methods the library implements for all the pipelines (such as downloading or saving, running on a particular
    device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        low_res_scheduler ([`SchedulerMixin`]):
            A scheduler used to add initial noise to the low res conditioning image. It must be an instance of
            [`DDPMScheduler`].
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
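
    Examples:
        A minimal usage sketch (the model id matches the demo in `main()` below; the file names are illustrative):

        ```py
        >>> import torch
        >>> from PIL import Image

        >>> pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-x4-upscaler", revision="fp16", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")
        >>> image = Image.open("low_res.png")
        >>> upscaled = pipe(image=image, prompt="a photo, highly detailed", tile_size=128, tile_border=32)
        >>> upscaled.save("upscaled.png")
        ```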
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        low_res_scheduler: DDPMScheduler,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        max_noise_level: int = 350,
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            max_noise_level=max_noise_level,
        )

    def _process_tile(self, original_image_slice, x, y, tile_size, tile_border, image, final_image, **kwargs):
        # Fix the seed so every tile is denoised with the same noise pattern,
        # which keeps neighbouring tiles consistent with each other.
        torch.manual_seed(0)
        # Crop rectangle for the tile at grid position (x, y); tiles in the
        # last row/column are shifted inward so a full tile (plus room for
        # the context slice) always fits.
        crop_rect = (
            min(image.size[0] - (tile_size + original_image_slice), x * tile_size),
            min(image.size[1] - (tile_size + original_image_slice), y * tile_size),
            min(image.size[0], (x + 1) * tile_size),
            min(image.size[1], (y + 1) * tile_size),
        )
        crop_rect_with_overlap = add_overlap_rect(crop_rect, tile_border, image.size)
        tile = image.crop(crop_rect_with_overlap)
        # Pick the slice of a tile-sized thumbnail of the whole image that is
        # centred on this tile, and prepend it to the tile as global context.
        translated_slice_x = ((crop_rect[0] + ((crop_rect[2] - crop_rect[0]) / 2)) / image.size[0]) * tile.size[0]
        translated_slice_x = translated_slice_x - (original_image_slice / 2)
        translated_slice_x = max(0, translated_slice_x)
        to_input = squeeze_tile(tile, image, original_image_slice, translated_slice_x)
        orig_input_size = to_input.size
        to_input = to_input.resize((tile_size, tile_size), Image.BICUBIC)
        # Run the regular StableDiffusionUpscalePipeline on the prepared tile,
        # then crop off the context strip and restore the 4x tile geometry.
        upscaled_tile = super(StableDiffusionTiledUpscalePipeline, self).__call__(image=to_input, **kwargs).images[0]
        upscaled_tile = upscaled_tile.resize((orig_input_size[0] * 4, orig_input_size[1] * 4), Image.BICUBIC)
        upscaled_tile = unsqueeze_tile(upscaled_tile, original_image_slice)
        upscaled_tile = upscaled_tile.resize((tile.size[0] * 4, tile.size[1] * 4), Image.BICUBIC)
        # Edges that touch the image border stay opaque; all other edges get a
        # linear alpha ramp so neighbouring tiles blend into each other.
        remove_borders = []
        if x == 0:
            remove_borders.append("l")
        elif crop_rect[2] == image.size[0]:
            remove_borders.append("r")
        if y == 0:
            remove_borders.append("t")
        elif crop_rect[3] == image.size[1]:
            remove_borders.append("b")
        transparency_mask = Image.fromarray(
            make_transparency_mask(
                (upscaled_tile.size[0], upscaled_tile.size[1]), tile_border * 4, remove_borders=remove_borders
            ),
            mode="L",
        )
        final_image.paste(
            upscaled_tile, (crop_rect_with_overlap[0] * 4, crop_rect_with_overlap[1] * 4), transparency_mask
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[PIL.Image.Image, List[PIL.Image.Image]],
        num_inference_steps: int = 75,
        guidance_scale: float = 9.0,
        noise_level: int = 50,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        callback: Optional[Callable[[dict], None]] = None,
        callback_steps: int = 1,  # not used by the tiled pipeline
        tile_size: int = 128,
        tile_border: int = 32,
        original_image_slice: int = 32,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`PIL.Image.Image` or `List[PIL.Image.Image]`):
                The image or images to upscale.
            num_inference_steps (`int`, *optional*, defaults to 75):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 9.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            noise_level (`int`, *optional*, defaults to 50):
                The amount of noise added to the low resolution conditioning image; must not exceed the pipeline's
                `max_noise_level`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            tile_size (`int`, *optional*, defaults to 128):
                The size of the tiles. Tiles that are too big can cause an out-of-memory error.
            tile_border (`int`, *optional*, defaults to 32):
                The number of pixels around a tile to consider (bigger means fewer seams, but too big can cause an
                out-of-memory error).
            original_image_slice (`int`, *optional*, defaults to 32):
                The number of pixels of the original image processed together with the current tile (bigger means more
                depth is preserved and less blur occurs in the final image, but too big can cause an out-of-memory
                error or loss in detail).
            callback (`Callable`, *optional*):
                A function called after each tile with a single `dict` argument that contains the (partially)
                processed image under `"image"` and the progress (0 to 1, where 1 is completed) under `"progress"`.

        Returns: A `PIL.Image.Image` that is 4 times larger in each dimension than the original input image.
        """

        final_image = Image.new("RGB", (image.size[0] * 4, image.size[1] * 4))
        tcx = math.ceil(image.size[0] / tile_size)
        tcy = math.ceil(image.size[1] / tile_size)
        total_tile_count = tcx * tcy
        current_count = 0
        # Upscale the image tile by tile; each upscaled tile is pasted into
        # the 4x larger canvas with an alpha mask that blends the overlaps.
        for y in range(tcy):
            for x in range(tcx):
                self._process_tile(
                    original_image_slice,
                    x,
                    y,
                    tile_size,
                    tile_border,
                    image,
                    final_image,
                    prompt=prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    noise_level=noise_level,
                    negative_prompt=negative_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    eta=eta,
                    generator=generator,
                    latents=latents,
                )
                current_count += 1
                if callback is not None:
                    callback({"progress": current_count / total_tile_count, "image": final_image})
        return final_image


def main():
    # Run a demo
    model_id = "stabilityai/stable-diffusion-x4-upscaler"
    pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    image = Image.open("../../docs/source/imgs/diffusers_library.jpg")

    def callback(obj):
        print(f"progress: {obj['progress']:.4f}")
        obj["image"].save("diffusers_library_progress.jpg")

    final_image = pipe(image=image, prompt="Black font, white background, vector", noise_level=40, callback=callback)
    final_image.save("diffusers_library.jpg")


if __name__ == "__main__":
    main()