Path: blob/main/examples/community/tiled_upscaling.py
# Copyright 2023 Peter Willemsen <[email protected]>. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Callable, List, Optional, Union

import numpy as np
import PIL
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from diffusers.schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler


def make_transparency_mask(size, overlap_pixels, remove_borders=[]):
    """Build a uint8 alpha mask of the given (width, height) `size` that is fully opaque (255) in the
    interior and ramps linearly to 0 over `overlap_pixels` at each border. Borders listed in
    `remove_borders` ("l", "r", "t", "b") stay fully opaque, e.g. at the image edges."""
    size_x = size[0] - overlap_pixels * 2
    size_y = size[1] - overlap_pixels * 2
    for letter in ["l", "r"]:
        if letter in remove_borders:
            size_x += overlap_pixels
    for letter in ["t", "b"]:
        if letter in remove_borders:
            size_y += overlap_pixels
    mask = np.ones((size_y, size_x), dtype=np.uint8) * 255
    mask = np.pad(mask, mode="linear_ramp", pad_width=overlap_pixels, end_values=0)

    if "l" in remove_borders:
        mask = mask[:, overlap_pixels : mask.shape[1]]
    if "r" in remove_borders:
        mask = mask[:, 0 : mask.shape[1] - overlap_pixels]
    if "t" in remove_borders:
        mask = mask[overlap_pixels : mask.shape[0], :]
    if "b" in remove_borders:
        mask = mask[0 : mask.shape[0] - overlap_pixels, :]
    return mask


def clamp(n, smallest, largest):
    return max(smallest, min(n, largest))


def clamp_rect(rect: List[int], min: List[int], max: List[int]):
    # Clamp a (left, top, right, bottom) rectangle to the given min/max corners.
    return (
        clamp(rect[0], min[0], max[0]),
        clamp(rect[1], min[1], max[1]),
        clamp(rect[2], min[0], max[0]),
        clamp(rect[3], min[1], max[1]),
    )


def add_overlap_rect(rect: List[int], overlap: int, image_size: List[int]):
    # Grow a rectangle by `overlap` pixels on every side, clamped to the image bounds.
    rect = list(rect)
    rect[0] -= overlap
    rect[1] -= overlap
    rect[2] += overlap
    rect[3] += overlap
    rect = clamp_rect(rect, [0, 0], [image_size[0], image_size[1]])
    return rect


def squeeze_tile(tile, original_image, original_slice, slice_x):
    """Prepend an `original_slice`-pixel-wide strip of the downscaled full image (taken at `slice_x`)
    to the left of `tile`, so the upscaler sees some global context next to the tile."""
    result = Image.new("RGB", (tile.size[0] + original_slice, tile.size[1]))
    result.paste(
        original_image.resize((tile.size[0], tile.size[1]), Image.BICUBIC).crop(
            (slice_x, 0, slice_x + original_slice, tile.size[1])
        ),
        (0, 0),
    )
    result.paste(tile, (original_slice, 0))
    return result


def unsqueeze_tile(tile, original_image_slice):
    """Remove the context strip added by `squeeze_tile` (now 4x wider, since the upscaler scales by 4)."""
    crop_rect = (original_image_slice * 4, 0, tile.size[0], tile.size[1])
    tile = tile.crop(crop_rect)
    return tile


def next_divisible(n, d):
    # Largest multiple of d that does not exceed n.
    remainder = n % d
    return n - remainder


class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
    r"""
    Pipeline for tile-based text-guided image super-resolution using Stable Diffusion 2, trading memory for compute
    to create gigantic images.

    This model inherits from [`StableDiffusionUpscalePipeline`]. Check the superclass documentation for the generic
    methods the library implements for all the pipelines (such as downloading or saving, running on a particular
    device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        low_res_scheduler ([`SchedulerMixin`]):
            A scheduler used to add initial noise to the low resolution conditioning image. It must be an instance of
            [`DDPMScheduler`].
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        low_res_scheduler: DDPMScheduler,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        max_noise_level: int = 350,
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            max_noise_level=max_noise_level,
        )

    def _process_tile(self, original_image_slice, x, y, tile_size, tile_border, image, final_image, **kwargs):
        torch.manual_seed(0)
        # Region of the low-resolution image covered by this tile, shifted so it never leaves the image.
        crop_rect = (
            min(image.size[0] - (tile_size + original_image_slice), x * tile_size),
            min(image.size[1] - (tile_size + original_image_slice), y * tile_size),
            min(image.size[0], (x + 1) * tile_size),
            min(image.size[1], (y + 1) * tile_size),
        )
        crop_rect_with_overlap = add_overlap_rect(crop_rect, tile_border, image.size)
        tile = image.crop(crop_rect_with_overlap)
        # Pick the strip of the full image that sits roughly at the tile's horizontal position.
        translated_slice_x = ((crop_rect[0] + ((crop_rect[2] - crop_rect[0]) / 2)) / image.size[0]) * tile.size[0]
        translated_slice_x = translated_slice_x - (original_image_slice / 2)
        translated_slice_x = max(0, translated_slice_x)
        to_input = squeeze_tile(tile, image, original_image_slice, translated_slice_x)
        orig_input_size = to_input.size
        to_input = to_input.resize((tile_size, tile_size), Image.BICUBIC)
        # Upscale the tile with the regular StableDiffusionUpscalePipeline, then undo the squeezing.
        upscaled_tile = super(StableDiffusionTiledUpscalePipeline, self).__call__(image=to_input, **kwargs).images[0]
        upscaled_tile = upscaled_tile.resize((orig_input_size[0] * 4, orig_input_size[1] * 4), Image.BICUBIC)
        upscaled_tile = unsqueeze_tile(upscaled_tile, original_image_slice)
        upscaled_tile = upscaled_tile.resize((tile.size[0] * 4, tile.size[1] * 4), Image.BICUBIC)
        # Keep tile borders that coincide with the image edges fully opaque; ramp the others for blending.
        remove_borders = []
        if x == 0:
            remove_borders.append("l")
        elif crop_rect[2] == image.size[0]:
            remove_borders.append("r")
        if y == 0:
            remove_borders.append("t")
        elif crop_rect[3] == image.size[1]:
            remove_borders.append("b")
        transparency_mask = Image.fromarray(
            make_transparency_mask(
                (upscaled_tile.size[0], upscaled_tile.size[1]), tile_border * 4, remove_borders=remove_borders
            ),
            mode="L",
        )
        final_image.paste(
            upscaled_tile, (crop_rect_with_overlap[0] * 4, crop_rect_with_overlap[1] * 4), transparency_mask
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[PIL.Image.Image, List[PIL.Image.Image]],
        num_inference_steps: int = 75,
        guidance_scale: float = 9.0,
        noise_level: int = 50,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        callback: Optional[Callable[[dict], None]] = None,
        callback_steps: int = 1,
        tile_size: int = 128,
        tile_border: int = 32,
        original_image_slice: int = 32,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
                `Image`, or tensor representing an image batch, which will be upscaled.
            num_inference_steps (`int`, *optional*, defaults to 75):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 9.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages images that are closely linked to the text `prompt`, usually
                at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`] and is ignored for other schedulers.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            tile_size (`int`, *optional*, defaults to 128):
                The size of the square tiles, in pixels. Too big can result in an OOM error.
            tile_border (`int`, *optional*, defaults to 32):
                The number of pixels around a tile to consider (bigger means fewer seams; too big can lead to an OOM
                error).
            original_image_slice (`int`, *optional*, defaults to 32):
                The number of pixels of the original image processed together with the current tile (bigger preserves
                more of the original composition and reduces blur in the final image; too big can lead to an OOM
                error or a loss of detail).
            callback (`Callable`, *optional*):
                A function called after each processed tile with a single argument, a dict that contains the
                (partially) processed image under `"image"` and the progress (0 to 1, where 1 is completed) under
                `"progress"`.

        Returns: A `PIL.Image.Image` that is 4 times larger than the original input image.

        """

        final_image = Image.new("RGB", (image.size[0] * 4, image.size[1] * 4))
        tcx = math.ceil(image.size[0] / tile_size)
        tcy = math.ceil(image.size[1] / tile_size)
        total_tile_count = tcx * tcy
        current_count = 0
        for y in range(tcy):
            for x in range(tcx):
                self._process_tile(
                    original_image_slice,
                    x,
                    y,
                    tile_size,
                    tile_border,
                    image,
                    final_image,
                    prompt=prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    noise_level=noise_level,
                    negative_prompt=negative_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    eta=eta,
                    generator=generator,
                    latents=latents,
                )
                current_count += 1
                if callback is not None:
                    callback({"progress": current_count / total_tile_count, "image": final_image})
        return final_image


def main():
    # Run a demo
    model_id = "stabilityai/stable-diffusion-x4-upscaler"
    pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    image = Image.open("../../docs/source/imgs/diffusers_library.jpg")

    def callback(obj):
        print(f"progress: {obj['progress']:.4f}")
        obj["image"].save("diffusers_library_progress.jpg")

    final_image = pipe(image=image, prompt="Black font, white background, vector", noise_level=40, callback=callback)
    final_image.save("diffusers_library.jpg")


if __name__ == "__main__":
    main()
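

# ------------------------------------------------------------------------------
# Illustrative sketch, not used by the pipeline: a minimal, self-contained
# demonstration of how the linear-ramp masks from `make_transparency_mask` blend
# two neighbouring tiles instead of leaving a hard seam. `blend_demo`, the tile
# sizes, and the output file name are hypothetical values chosen for this sketch.
def blend_demo():
    overlap = 8
    tile_a = Image.new("RGB", (64, 64), (255, 0, 0))
    tile_b = Image.new("RGB", (64, 64), (0, 0, 255))
    # Two 64 px tiles share a 2 * overlap wide strip on their common edge.
    canvas = Image.new("RGB", (2 * 64 - 2 * overlap, 64))
    # Keep every border opaque except the one facing the neighbouring tile,
    # which ramps from 255 down to 0 over `overlap` pixels.
    mask_a = Image.fromarray(make_transparency_mask((64, 64), overlap, remove_borders=["l", "t", "b"]), mode="L")
    mask_b = Image.fromarray(make_transparency_mask((64, 64), overlap, remove_borders=["r", "t", "b"]), mode="L")
    canvas.paste(tile_a, (0, 0), mask_a)
    canvas.paste(tile_b, (64 - 2 * overlap, 0), mask_b)
    canvas.save("blend_demo.png")  # red fades into blue across the overlap region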