Path: blob/main/examples/community/lpw_stable_diffusion_onnx.py
import inspect
import re
from typing import Callable, List, Optional, Union

import numpy as np
import PIL
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTokenizer

import diffusers
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import logging


try:
    from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
except ImportError:
    ORT_TO_NP_TYPE = {
        "tensor(bool)": np.bool_,
        "tensor(int8)": np.int8,
        "tensor(uint8)": np.uint8,
        "tensor(int16)": np.int16,
        "tensor(uint16)": np.uint16,
        "tensor(int32)": np.int32,
        "tensor(uint32)": np.uint32,
        "tensor(int64)": np.int64,
        "tensor(uint64)": np.uint64,
        "tensor(float16)": np.float16,
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
    }

try:
    from diffusers.utils import PIL_INTERPOLATION
except ImportError:
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        PIL_INTERPOLATION = {
            "linear": PIL.Image.Resampling.BILINEAR,
            "bilinear": PIL.Image.Resampling.BILINEAR,
            "bicubic": PIL.Image.Resampling.BICUBIC,
            "lanczos": PIL.Image.Resampling.LANCZOS,
            "nearest": PIL.Image.Resampling.NEAREST,
        }
    else:
        PIL_INTERPOLATION = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
            "nearest": PIL.Image.NEAREST,
        }
# ------------------------------------------------------------------------------

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return their tokens together with the weight of each token.

    No padding, starting or ending token is included.
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1]
            text_token += list(token)
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights


def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights


def get_unweighted_text_embeddings(
    pipe,
    text_input: np.array,
    chunk_length: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].copy()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            text_input_chunk[:, -1] = text_input[0, -1]

            text_embedding = pipe.text_encoder(input_ids=text_input_chunk)[0]

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = np.concatenate(text_embeddings, axis=1)
    else:
        text_embeddings = pipe.text_encoder(input_ids=text_input)[0]
    return text_embeddings


def get_weighted_text_embeddings(
    pipe,
    prompt: Union[str, List[str]],
    uncond_prompt: Optional[Union[str, List[str]]] = None,
    max_embeddings_multiples: Optional[int] = 4,
    no_boseos_middle: Optional[bool] = False,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
    **kwargs,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

    Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.

    Args:
        pipe (`OnnxStableDiffusionPipeline`):
            Pipe to provide access to the tokenizer and the text encoder.
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        uncond_prompt (`str` or `List[str]`):
            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
            is provided, the embeddings of prompt and uncond_prompt are concatenated.
        max_embeddings_multiples (`int`, *optional*, defaults to `4`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If the length of text token is multiples of the capacity of text encoder, whether to reserve the starting
            and ending token in each chunk in the middle.
        skip_parsing (`bool`, *optional*, defaults to `False`):
            Skip the parsing of brackets.
        skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When the parsing is skipped, it is forced True.
    """
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    if not skip_parsing:
        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
    else:
        prompt_tokens = [
            token[1:-1]
            for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True, return_tensors="np").input_ids
        ]
        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens = [
                token[1:-1]
                for token in pipe.tokenizer(
                    uncond_prompt,
                    max_length=max_length,
                    truncation=True,
                    return_tensors="np",
                ).input_ids
            ]
            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])
    if uncond_prompt is not None:
        max_length = max(max_length, max([len(token) for token in uncond_tokens]))

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=pipe.tokenizer.model_max_length,
    )
    prompt_tokens = np.array(prompt_tokens, dtype=np.int32)
    if uncond_prompt is not None:
        uncond_tokens, uncond_weights = pad_tokens_and_weights(
            uncond_tokens,
            uncond_weights,
            max_length,
            bos,
            eos,
            no_boseos_middle=no_boseos_middle,
            chunk_length=pipe.tokenizer.model_max_length,
        )
        uncond_tokens = np.array(uncond_tokens, dtype=np.int32)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        pipe,
        prompt_tokens,
        pipe.tokenizer.model_max_length,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
    if uncond_prompt is not None:
        uncond_embeddings = get_unweighted_text_embeddings(
            pipe,
            uncond_tokens,
            pipe.tokenizer.model_max_length,
            no_boseos_middle=no_boseos_middle,
        )
        uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)

    # assign weights to the prompts and normalize in the sense of mean
    # TODO: should we normalize by chunk or in a whole (current implementation)?
    if (not skip_parsing) and (not skip_weighting):
        previous_mean = text_embeddings.mean(axis=(-2, -1))
        text_embeddings *= prompt_weights[:, :, None]
        text_embeddings *= (previous_mean / text_embeddings.mean(axis=(-2, -1)))[:, None, None]
        if uncond_prompt is not None:
            previous_mean = uncond_embeddings.mean(axis=(-2, -1))
            uncond_embeddings *= uncond_weights[:, :, None]
            uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=(-2, -1)))[:, None, None]

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if uncond_prompt is not None:
        return text_embeddings, uncond_embeddings

    return text_embeddings


def preprocess_image(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    return 2.0 * image - 1.0


def preprocess_mask(mask, scale_factor=8):
    mask = mask.convert("L")
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
    mask = mask[None].transpose(0, 1, 2, 3)  # transpose(0, 1, 2, 3) is an identity permutation, effectively a no-op
    mask = 1 - mask  # repaint white, keep black
    return mask


class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
    parsing weights in the prompt.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    """
    if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):

        def __init__(
            self,
            vae_encoder: OnnxRuntimeModel,
            vae_decoder: OnnxRuntimeModel,
            text_encoder: OnnxRuntimeModel,
            tokenizer: CLIPTokenizer,
            unet: OnnxRuntimeModel,
            scheduler: SchedulerMixin,
            safety_checker: OnnxRuntimeModel,
            feature_extractor: CLIPImageProcessor,
            requires_safety_checker: bool = True,
        ):
            super().__init__(
                vae_encoder=vae_encoder,
                vae_decoder=vae_decoder,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                unet=unet,
                scheduler=scheduler,
                safety_checker=safety_checker,
                feature_extractor=feature_extractor,
                requires_safety_checker=requires_safety_checker,
            )
            self.__init__additional__()

    else:

        def __init__(
            self,
            vae_encoder: OnnxRuntimeModel,
            vae_decoder: OnnxRuntimeModel,
            text_encoder: OnnxRuntimeModel,
            tokenizer: CLIPTokenizer,
            unet: OnnxRuntimeModel,
            scheduler: SchedulerMixin,
            safety_checker: OnnxRuntimeModel,
            feature_extractor: CLIPImageProcessor,
        ):
            super().__init__(
                vae_encoder=vae_encoder,
                vae_decoder=vae_decoder,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                unet=unet,
                scheduler=scheduler,
                safety_checker=safety_checker,
                feature_extractor=feature_extractor,
            )
            self.__init__additional__()

    def __init__additional__(self):
        self.unet_in_channels = 4
        self.vae_scale_factor = 8

    def _encode_prompt(
        self,
        prompt,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        max_embeddings_multiples,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
            pipe=self,
            prompt=prompt,
            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
            max_embeddings_multiples=max_embeddings_multiples,
        )

        text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
        if do_classifier_free_guidance:
            uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

        return text_embeddings

    def check_inputs(self, prompt, height, width, strength, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def get_timesteps(self, num_inference_steps, strength, is_text2img):
        if is_text2img:
            return self.scheduler.timesteps, num_inference_steps
        else:
            # get the original timestep using init_timestep
            offset = self.scheduler.config.get("steps_offset", 0)
            init_timestep = int(num_inference_steps * strength) + offset
            init_timestep = min(init_timestep, num_inference_steps)

            t_start = max(num_inference_steps - init_timestep + offset, 0)
            timesteps = self.scheduler.timesteps[t_start:]
            return timesteps, num_inference_steps - t_start

    def run_safety_checker(self, image):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(image), return_tensors="np"
            ).pixel_values.astype(image.dtype)
            # calling the safety_checker directly with a batch size > 1 raises an error, so run it image by image
            images, has_nsfw_concept = [], []
            for i in range(image.shape[0]):
                image_i, has_nsfw_concept_i = self.safety_checker(
                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
                )
                images.append(image_i)
                has_nsfw_concept.append(has_nsfw_concept_i[0])
            image = np.concatenate(images)
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        # image = self.vae_decoder(latent_sample=latents)[0]
        # the half-precision vae decoder gives strange results when the batch size > 1, so decode one latent at a time
        image = np.concatenate(
            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
        )
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def prepare_latents(self, image, timestep, batch_size, height, width, dtype, generator, latents=None):
        if image is None:
            shape = (
                batch_size,
                self.unet_in_channels,
                height // self.vae_scale_factor,
                width // self.vae_scale_factor,
            )

            if latents is None:
                latents = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
            else:
                if latents.shape != shape:
                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")

            # scale the initial noise by the standard deviation required by the scheduler
            latents = (torch.from_numpy(latents) * self.scheduler.init_noise_sigma).numpy()
            return latents, None, None
        else:
            init_latents = self.vae_encoder(sample=image)[0]
            init_latents = 0.18215 * init_latents
            init_latents = np.concatenate([init_latents] * batch_size, axis=0)
            init_latents_orig = init_latents
            shape = init_latents.shape

            # add noise to latents using the timesteps
            noise = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
            latents = self.scheduler.add_noise(
                torch.from_numpy(init_latents), torch.from_numpy(noise), timestep
            ).numpy()
            return latents, init_latents_orig, noise

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        image: Union[np.ndarray, PIL.Image.Image] = None,
        mask_image: Union[np.ndarray, PIL.Image.Image] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        strength: float = 0.8,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[np.ndarray] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        is_cancelled_callback: Optional[Callable[[], bool]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            image (`np.ndarray` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process.
            mask_image (`np.ndarray` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
                noise will be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`np.ndarray`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
            is_cancelled_callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. If the function returns
                `True`, the inference will be cancelled.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            `None` if cancelled by `is_cancelled_callback`,
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, strength, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            max_embeddings_multiples,
        )
        dtype = text_embeddings.dtype

        # 4. Preprocess image and mask
        if isinstance(image, PIL.Image.Image):
            image = preprocess_image(image)
        if image is not None:
            image = image.astype(dtype)
        if isinstance(mask_image, PIL.Image.Image):
            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
        if mask_image is not None:
            mask = mask_image.astype(dtype)
            mask = np.concatenate([mask] * batch_size * num_images_per_prompt)
        else:
            mask = None

        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timestep_dtype = next(
            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
        )
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, image is None)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        latents, init_latents_orig, noise = self.prepare_latents(
            image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            height,
            width,
            dtype,
            generator,
            latents,
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
            latent_model_input = latent_model_input.numpy()

            # predict the noise residual
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=np.array([t], dtype=timestep_dtype),
                encoder_hidden_states=text_embeddings,
            )
            noise_pred = noise_pred[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            scheduler_output = self.scheduler.step(
                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
            )
            latents = scheduler_output.prev_sample.numpy()

            if mask is not None:
                # masking
                init_latents_proper = self.scheduler.add_noise(
                    torch.from_numpy(init_latents_orig),
                    torch.from_numpy(noise),
                    t,
                ).numpy()
                latents = (init_latents_proper * mask) + (latents * (1 - mask))

            # call the callback, if provided
            if i % callback_steps == 0:
                if callback is not None:
                    callback(i, t, latents)
                if is_cancelled_callback is not None and is_cancelled_callback():
                    return None

        # 9. Post-processing
        image = self.decode_latents(latents)

        # 10. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image)

        # 11. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return image, has_nsfw_concept

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def text2img(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[np.ndarray] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function for text-to-image generation.
        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`np.ndarray`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        return self.__call__(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            max_embeddings_multiples=max_embeddings_multiples,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    def img2img(
        self,
        image: Union[np.ndarray, PIL.Image.Image],
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function for image-to-image generation.
        Args:
            image (`np.ndarray` or `PIL.Image.Image`):
                `Image`, or ndarray representing an image batch, that will be used as the starting point for the
                process.
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
                noise will be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter will be modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        return self.__call__(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            strength=strength,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            max_embeddings_multiples=max_embeddings_multiples,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    def inpaint(
        self,
        image: Union[np.ndarray, PIL.Image.Image],
        mask_image: Union[np.ndarray, PIL.Image.Image],
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        Function for inpainting.
        Args:
            image (`np.ndarray` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. This is the image whose masked region will be inpainted.
            mask_image (`np.ndarray` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
                is 1, the denoising process will be run on the masked area for the full number of iterations specified
                in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        return self.__call__(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            mask_image=mask_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            strength=strength,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            max_embeddings_multiples=max_embeddings_multiples,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )
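

# ------------------------------------------------------------------------------
# Minimal usage sketch: one way this file can be loaded as the
# `lpw_stable_diffusion_onnx` community pipeline and driven through `text2img`.
# The model id, revision, and execution provider below are placeholders, not part
# of this module; substitute a Stable Diffusion checkpoint exported to ONNX and a
# provider available in your onnxruntime build.
if __name__ == "__main__":
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",  # placeholder: any SD checkpoint with ONNX weights
        custom_pipeline="lpw_stable_diffusion_onnx",
        revision="onnx",
        provider="CPUExecutionProvider",  # or "CUDAExecutionProvider" with onnxruntime-gpu
    )

    # weighted tokens and prompts longer than the usual 77-token limit are both supported
    prompt = "a photo of an astronaut riding a (white horse:1.2) on mars, (best quality:1.1)"
    neg_prompt = "lowres, bad anatomy, worst quality, low quality"

    image = pipe.text2img(
        prompt,
        negative_prompt=neg_prompt,
        width=512,
        height=512,
        max_embeddings_multiples=3,
    ).images[0]
    image.save("astronaut_lpw_onnx.png")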