# Path: blob/master/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py
import numpy as np
import gradio as gr
import math
from modules.ui_components import InputAccordion
import modules.scripts as scripts
from modules.torch_utils import float64


class SoftInpaintingSettings:
    """
    Container for the six soft-inpainting parameters.

    Besides holding the numeric settings of a generation, the same class is
    reused below (see the "Constants" section) to hold, per parameter, the UI
    label, the UI help text, the infotext label and the element id - so the
    field values are not always numbers.
    """

    def __init__(self,
                 mask_blend_power,
                 mask_blend_scale,
                 inpaint_detail_preservation,
                 composite_mask_influence,
                 composite_difference_threshold,
                 composite_difference_contrast):
        self.mask_blend_power = mask_blend_power
        self.mask_blend_scale = mask_blend_scale
        self.inpaint_detail_preservation = inpaint_detail_preservation
        self.composite_mask_influence = composite_mask_influence
        self.composite_difference_threshold = composite_difference_threshold
        self.composite_difference_contrast = composite_difference_contrast

    def add_generation_params(self, dest):
        """
        Write these settings into the mapping ``dest`` (e.g.
        ``p.extra_generation_params``), keyed by their infotext labels, together
        with the "Soft inpainting enabled" flag.
        """
        dest[enabled_gen_param_label] = True
        dest[gen_param_labels.mask_blend_power] = self.mask_blend_power
        dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale
        dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation
        dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence
        dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold
        dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast


# ------------------- Methods -------------------

def processing_uses_inpainting(p):
    """Return True if the processing object ``p`` carries any inpainting mask attribute."""
    # TODO: Figure out a better way to determine if inpainting is being used by p
    return any(getattr(p, attr, None) is not None for attr in ("image_mask", "mask", "nmask"))


def latent_blend(settings, a, b, t):
    """
    Interpolates two latent image representations according to the parameter t,
    where the interpolated vectors' magnitudes are also interpolated separately.
    The "detail_preservation" factor biases the magnitude interpolation towards
    the larger of the two magnitudes.

    Args:
        settings (SoftInpaintingSettings): only ``inpaint_detail_preservation``
            is read; it is the exponent of the power mean used on the magnitudes.
        a: the latent weighted by ``1 - t`` (the original latent in ``on_mask_blend``).
        b: the latent weighted by ``t`` (the current, denoised latent in ``on_mask_blend``).
        t: blend weights, either 3-D (one unbatched mask) or batched. Only its
            first channel is used for the magnitude interpolation.
            NOTE(review): magnitudes are taken over dim 1, so a and b are
            presumably (batch, channels, h, w) - confirm against callers.

    Returns:
        The blended latent, in the dtype of the linear interpolation of a and b.
    """
    import torch

    # NOTE: We use inplace operations wherever possible.

    if len(t.shape) == 3:
        # [4][w][h] to [1][4][w][h]
        t2 = t.unsqueeze(0)
        # [4][w][h] to [1][1][w][h] - the [4] seem redundant.
        t3 = t[0].unsqueeze(0).unsqueeze(0)
    else:
        t2 = t
        t3 = t[:, 0][:, None]

    one_minus_t2 = 1 - t2
    one_minus_t3 = 1 - t3

    # Linearly interpolate the image vectors.
    a_scaled = a * one_minus_t2
    b_scaled = b * t2
    image_interp = a_scaled
    image_interp.add_(b_scaled)
    result_type = image_interp.dtype
    del a_scaled, b_scaled, t2, one_minus_t2

    # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
    # 64-bit operations are used here to allow large exponents.
    current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(float64(image_interp)).add_(0.00001)

    # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
    a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(float64(a)).pow_(settings.inpaint_detail_preservation) * one_minus_t3
    b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(float64(b)).pow_(settings.inpaint_detail_preservation) * t3
    desired_magnitude = a_magnitude
    desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation)
    del a_magnitude, b_magnitude, t3, one_minus_t3

    # Change the linearly interpolated image vectors' magnitudes to the value we want.
    # This is the last 64-bit operation.
    image_interp_scaling_factor = desired_magnitude
    image_interp_scaling_factor.div_(current_magnitude)
    image_interp_scaling_factor = image_interp_scaling_factor.to(result_type)
    image_interp_scaled = image_interp
    image_interp_scaled.mul_(image_interp_scaling_factor)
    # Release intermediates explicitly before returning.
    del current_magnitude
    del desired_magnitude
    del image_interp
    del image_interp_scaling_factor
    del result_type

    return image_interp_scaled


def get_modified_nmask(settings, nmask, sigma):
    """
    Converts a negative mask representing the transparency of the original latent vectors being overlaid
    to a mask that is scaled according to the denoising strength for this step.

    Where:
        0 = fully opaque, infinite density, fully masked
        1 = fully transparent, zero density, fully unmasked

    We bring this transparency to a power, as this allows one to simulate N number of blending operations
    where N can be any positive real value. Using this one can control the balance of influence between
    the denoiser and the original latents according to the sigma value.

    The exponent is ``(sigma ** mask_blend_power) * mask_blend_scale``.

    NOTE: "mask" is not used
    """
    import torch
    return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale)


def apply_adaptive_masks(
        settings: SoftInpaintingSettings,
        nmask,
        latent_orig,
        latent_processed,
        overlay_images,
        width, height,
        paste_to):
    """
    Build one compositing mask per image from how much its latent changed during
    denoising, and cut those masks into the overlay images.

    For each image:
    - Take the per-position L2 distance (over dim 1) between ``latent_orig`` and
      ``latent_processed``.
    - Smooth it with two weighted-histogram filters.
    - Divide it by a threshold derived from ``composite_difference_threshold``.
      The threshold is biased by the latent mask according to
      ``composite_mask_influence``.
    - Map the result through a contrast curve.

    Where the difference is small, the original (overlay) pixels are kept.

    ``overlay_images`` is modified in place: each entry is replaced by a copy
    whose alpha is cut out by its mask.

    Returns:
        list[PIL.Image.Image]: one mask per overlay image, in the same order.
    """
    import torch
    import modules.processing as proc
    import modules.images as images
    from PIL import Image, ImageOps, ImageFilter

    # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
    if len(nmask.shape) == 3:
        latent_mask = nmask[0].float()
    else:
        latent_mask = nmask[:, 0].float()
    # convert the original mask into a form we use to scale distances for thresholding
    mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
    mask_scalar = (0.5 * (1 - settings.composite_mask_influence)
                   + mask_scalar * settings.composite_mask_influence)
    # Map [0, 1] to an odds-like ratio; the 1.00001 keeps the division finite at 1.
    mask_scalar = mask_scalar / (1.00001 - mask_scalar)
    mask_scalar = mask_scalar.cpu().numpy()

    latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)

    kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2)

    masks_for_overlay = []

    for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
        converted_mask = distance_map.float().cpu().numpy()
        converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
                                                   percentile_min=0.9, percentile_max=1, min_width=1)
        converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
                                                   percentile_min=0.25, percentile_max=0.75, min_width=1)

        # The distance at which opacity of original decreases to 50%
        if len(mask_scalar.shape) == 3:
            if mask_scalar.shape[0] > i:
                half_weighted_distance = settings.composite_difference_threshold * mask_scalar[i]
            else:
                half_weighted_distance = settings.composite_difference_threshold * mask_scalar[0]
        else:
            half_weighted_distance = settings.composite_difference_threshold * mask_scalar

        converted_mask = converted_mask / half_weighted_distance

        converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)
        converted_mask = smootherstep(converted_mask)
        converted_mask = 1 - converted_mask
        converted_mask = 255. * converted_mask
        converted_mask = converted_mask.astype(np.uint8)
        converted_mask = Image.fromarray(converted_mask)
        converted_mask = images.resize_image(2, converted_mask, width, height)
        converted_mask = proc.create_binary_mask(converted_mask, round=False)

        # Remove aliasing artifacts using a gaussian blur.
        converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))

        # Expand the mask to fit the whole image if needed.
        if paste_to is not None:
            converted_mask = proc.uncrop(converted_mask,
                                         (overlay_image.width, overlay_image.height),
                                         paste_to)

        masks_for_overlay.append(converted_mask)

        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
                           mask=ImageOps.invert(converted_mask.convert('L')))

        overlay_images[i] = image_masked.convert('RGBA')

    return masks_for_overlay


def apply_masks(
        settings,
        nmask,
        overlay_images,
        width, height,
        paste_to):
    """
    Build compositing masks directly from the latent mask ``nmask`` (no
    difference analysis) and cut them into the overlay images.

    This is used by ``Script.post_sample`` when the samples are already decoded.

    Mask selection:
    - A 3-D ``nmask`` gives one mask (from its first channel), shared by all images.
    - A batched ``nmask`` gives one mask per batch entry. Images beyond the
      batch size fall back to the first mask, as ``apply_adaptive_masks`` does.

    ``overlay_images`` is modified in place: each entry is replaced by a copy
    whose alpha is cut out by its mask.

    Returns:
        list[PIL.Image.Image]: one mask per overlay image, in the same order.
    """
    import torch
    import modules.processing as proc
    import modules.images as images
    from PIL import Image, ImageOps, ImageFilter

    def latent_to_mask(latent_mask):
        # Latent-resolution mask -> blurred, processing-resolution PIL mask.
        mask = torch.clamp(latent_mask.float(), min=0, max=1).pow_(settings.mask_blend_scale / 2)
        mask = 255. * mask
        mask = mask.cpu().numpy().astype(np.uint8)
        mask = Image.fromarray(mask)
        mask = images.resize_image(2, mask, width, height)
        mask = proc.create_binary_mask(mask, round=False)

        # Remove aliasing artifacts using a gaussian blur.
        return mask.filter(ImageFilter.GaussianBlur(radius=4))

    if len(nmask.shape) == 3:
        base_masks = [latent_to_mask(nmask[0])]
    else:
        base_masks = [latent_to_mask(m) for m in nmask[:, 0]]

    masks_for_overlay = []

    for i, overlay_image in enumerate(overlay_images):
        converted_mask = base_masks[i] if i < len(base_masks) else base_masks[0]

        # Expand the mask to fit the whole image if needed.
        # The destination must be the overlay image's own size: with paste_to
        # set, width/height describe the cropped region only, and the paste
        # below requires the mask to match the overlay image's size.
        if paste_to is not None:
            converted_mask = proc.uncrop(converted_mask,
                                         (overlay_image.width, overlay_image.height),
                                         paste_to)

        # Append (the list starts empty; index assignment would raise IndexError).
        masks_for_overlay.append(converted_mask)

        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
                           mask=ImageOps.invert(converted_mask.convert('L')))

        overlay_images[i] = image_masked.convert('RGBA')

    return masks_for_overlay


def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
    """
    Generalized convolution filter capable of applying
    weighted mean, median, maximum, and minimum filters
    parametrically using an arbitrary kernel.

    For every pixel, the neighbourhood values are sorted and stacked by their
    kernel weights. The output is the weighted average of the slice of that
    stack between ``percentile_min`` and ``percentile_max``.

    Args:
        img (nparray):
            The image, a 2-D array of floats, to which the filter is being applied.
        kernel (nparray):
            The kernel, a 2-D array of floats.
        kernel_center (nparray or int):
            The kernel center coordinate, a 1-D array with two elements, or a
            scalar used for both axes (as returned by get_gaussian_kernel).
        percentile_min (float):
            The lower bound of the histogram window used by the filter,
            from 0 to 1.
        percentile_max (float):
            The upper bound of the histogram window used by the filter,
            from 0 to 1.
        min_width (float):
            The minimum size of the histogram window bounds, in weight units.
            Must be greater than 0.

    Returns:
        (nparray): A filtered copy of the input image "img", a 2-D array of floats.
    """

    # Converts an index tuple into a vector.
    def vec(x):
        return np.array(x)

    class WeightedElement:
        """
        An element of the histogram: a sample value, its weight,
        and the range [window_min, window_max) it occupies in the weight stack.
        """

        __slots__ = ("value", "weight", "window_min", "window_max")

        def __init__(self, value, weight):
            self.value: float = value
            self.weight: float = weight
            self.window_min: float = 0.0
            self.window_max: float = 1.0

    kernel_min = -kernel_center
    kernel_max = vec(kernel.shape) - kernel_center
    img_shape = vec(img.shape)

    def weighted_histogram_filter_single(idx):
        idx = vec(idx)
        # Clip the kernel footprint to the image bounds.
        min_index = np.maximum(0, idx + kernel_min)
        max_index = np.minimum(img_shape, idx + kernel_max)
        window_shape = max_index - min_index

        # Collect the values in the image as WeightedElements,
        # weighted by their corresponding kernel values.
        values = []
        for window_tup in np.ndindex(tuple(window_shape)):
            window_index = vec(window_tup)
            image_index = window_index + min_index
            centered_kernel_index = image_index - idx
            kernel_index = centered_kernel_index + kernel_center
            values.append(WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]))

        values.sort(key=lambda element: element.value)

        # Calculate the height of the stack (total weight)
        # and each sample's range they occupy in the stack.
        total = 0
        for element in values:
            element.window_min = total
            total += element.weight
            element.window_max = total

        # Calculate what range of this stack ("window")
        # we want to get the weighted average across.
        window_min = total * percentile_min
        window_max = total * percentile_max
        window_width = window_max - window_min

        # Ensure the window is within the stack and at least a certain size.
        if window_width < min_width:
            window_center = (window_min + window_max) / 2
            window_min = window_center - min_width / 2
            window_max = window_center + min_width / 2

            if window_max > total:
                window_max = total
                window_min = total - min_width

            if window_min < 0:
                window_min = 0
                window_max = min_width

        value = 0
        value_weight = 0

        # Get the weighted average of all the samples
        # that overlap with the window, weighted
        # by the size of their overlap.
        for element in values:
            if window_min >= element.window_max:
                continue
            if window_max <= element.window_min:
                break

            s = max(window_min, element.window_min)
            e = min(window_max, element.window_max)
            w = e - s

            value += element.value * w
            value_weight += w

        return value / value_weight if value_weight != 0 else 0

    # Read from img, write to a copy, so already-filtered pixels never feed back in.
    img_out = img.copy()

    # Apply the kernel operation over each pixel.
    for index in np.ndindex(img.shape):
        img_out[index] = weighted_histogram_filter_single(index)

    return img_out


def smoothstep(x):
    """
    The smoothstep function, input should be clamped to 0-1 range.
    Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
    """
    return x * x * (3 - 2 * x)


def smootherstep(x):
    """
    The smootherstep function, input should be clamped to 0-1 range.
    Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
    """
    return x * x * x * (x * (6 * x - 15) + 10)


def get_gaussian_kernel(stddev_radius=1.0, max_radius=2):
    """
    Creates a Gaussian kernel with thresholded edges.

    Args:
        stddev_radius (float):
            Width of the gaussian kernel, in pixels. NOTE: the curve is
            exp(-d^2 / stddev_radius^2), i.e. without the usual factor of 2, so
            this is not literally the standard deviation.
        max_radius (int):
            The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2.
            The kernel is thresholded so that any values one pixel beyond this radius
            is weighted at 0.

    Returns:
        (nparray, int): A kernel array (shape: (N, N), N = max_radius*2+1) and the
        index of its center, which is the same on both axes (max_radius).
    """

    # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean.
    def gaussian(sqr_mag):
        return math.exp(-sqr_mag / (stddev_radius * stddev_radius))

    # Helper function for converting a tuple to an array.
    def vec(x):
        return np.array(x)

    # Since a gaussian is unbounded, we need to limit ourselves to a finite range.
    # We taper the ends off at the end of that range so they equal zero
    # while preserving the maximum value of 1 at the mean.
    zero_radius = max_radius + 1.0
    gauss_zero = gaussian(zero_radius * zero_radius)
    gauss_kernel_scale = 1 / (1 - gauss_zero)

    def gaussian_kernel_func(coordinate):
        x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0
        x = gaussian(x)
        x -= gauss_zero
        x *= gauss_kernel_scale
        x = max(0.0, x)
        return x

    size = max_radius * 2 + 1
    kernel_center = max_radius
    kernel = np.zeros((size, size))

    for index in np.ndindex(kernel.shape):
        kernel[index] = gaussian_kernel_func(vec(index) - kernel_center)

    return kernel, kernel_center


# ------------------- Constants -------------------


default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2)

enabled_ui_label = "Soft inpainting"
enabled_gen_param_label = "Soft inpainting enabled"
enabled_el_id = "soft_inpainting_enabled"

ui_labels = SoftInpaintingSettings(
    "Schedule bias",
    "Preservation strength",
    "Transition contrast boost",
    "Mask influence",
    "Difference threshold",
    "Difference contrast")

ui_info = SoftInpaintingSettings(
    "Shifts when preservation of original content occurs during denoising.",
    "How strongly partially masked content should be preserved.",
    "Amplifies the contrast that may be lost in partially masked regions.",
    "How strongly the original mask should bias the difference threshold.",
    "How much an image region can change before the original pixels are not blended in anymore.",
    "How sharp the transition should be between blended and not blended.")

gen_param_labels = SoftInpaintingSettings(
    "Soft inpainting schedule bias",
    "Soft inpainting preservation strength",
    "Soft inpainting transition contrast boost",
    "Soft inpainting mask influence",
    "Soft inpainting difference threshold",
    "Soft inpainting difference contrast")

el_ids = SoftInpaintingSettings(
    "mask_blend_power",
    "mask_blend_scale",
    "inpaint_detail_preservation",
    "composite_mask_influence",
    "composite_difference_threshold",
    "composite_difference_contrast")


# ------------------- Script -------------------


class Script(scripts.Script):
    """
    Always-visible img2img script implementing soft inpainting.

    Hooks:
    - ``process`` records the settings in the generation parameters.
    - ``on_mask_blend`` replaces the per-step latent mask blend with ``latent_blend``.
    - ``post_sample`` rebuilds the overlay images and their masks.
    - ``postprocess_maskoverlay`` hands those masks and overlays back per image.
    """

    def __init__(self):
        self.section = "inpaint"
        # Filled in by post_sample, consumed by postprocess_maskoverlay.
        self.masks_for_overlay = None
        self.overlay_images = None

    def title(self):
        return "Soft Inpainting"

    def show(self, is_img2img):
        return scripts.AlwaysVisible if is_img2img else False

    def ui(self, is_img2img):
        if not is_img2img:
            return

        with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:
            with gr.Group():
                gr.Markdown(
                    """
                    Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity.
                    **High _Mask blur_** values are recommended!
                    """)

                power = \
                    gr.Slider(label=ui_labels.mask_blend_power,
                              info=ui_info.mask_blend_power,
                              minimum=0,
                              maximum=8,
                              step=0.1,
                              value=default.mask_blend_power,
                              elem_id=el_ids.mask_blend_power)
                scale = \
                    gr.Slider(label=ui_labels.mask_blend_scale,
                              info=ui_info.mask_blend_scale,
                              minimum=0,
                              maximum=8,
                              step=0.05,
                              value=default.mask_blend_scale,
                              elem_id=el_ids.mask_blend_scale)
                detail = \
                    gr.Slider(label=ui_labels.inpaint_detail_preservation,
                              info=ui_info.inpaint_detail_preservation,
                              minimum=1,
                              maximum=32,
                              step=0.5,
                              value=default.inpaint_detail_preservation,
                              elem_id=el_ids.inpaint_detail_preservation)

                gr.Markdown(
                    """
                    ### Pixel Composite Settings
                    """)

                mask_inf = \
                    gr.Slider(label=ui_labels.composite_mask_influence,
                              info=ui_info.composite_mask_influence,
                              minimum=0,
                              maximum=1,
                              step=0.05,
                              value=default.composite_mask_influence,
                              elem_id=el_ids.composite_mask_influence)

                dif_thresh = \
                    gr.Slider(label=ui_labels.composite_difference_threshold,
                              info=ui_info.composite_difference_threshold,
                              minimum=0,
                              maximum=8,
                              step=0.25,
                              value=default.composite_difference_threshold,
                              elem_id=el_ids.composite_difference_threshold)

                dif_contr = \
                    gr.Slider(label=ui_labels.composite_difference_contrast,
                              info=ui_info.composite_difference_contrast,
                              minimum=0,
                              maximum=8,
                              step=0.25,
                              value=default.composite_difference_contrast,
                              elem_id=el_ids.composite_difference_contrast)

                with gr.Accordion("Help", open=False):
                    gr.Markdown(
                        f"""
                        ### {ui_labels.mask_blend_power}

                        The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas).
                        This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step.
                        This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation.

                        - **Below 1**: Stronger preservation near the end (with low sigma)
                        - **1**: Balanced (proportional to sigma)
                        - **Above 1**: Stronger preservation in the beginning (with high sigma)
                        """)
                    gr.Markdown(
                        f"""
                        ### {ui_labels.mask_blend_scale}

                        Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content.
                        This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength.

                        - **Low values**: Favors generated content.
                        - **High values**: Favors original content.
                        """)
                    gr.Markdown(
                        f"""
                        ### {ui_labels.inpaint_detail_preservation}

                        This parameter controls how the original latent vectors and denoised latent vectors are interpolated.
                        With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors.
                        This can prevent the loss of contrast that occurs with linear interpolation.

                        - **Low values**: Softer blending, details may fade.
                        - **High values**: Stronger contrast, may over-saturate colors.
                        """)

                    gr.Markdown(
                        """
                        ## Pixel Composite Settings

                        Masks are generated based on how much a part of the image changed after denoising.
                        These masks are used to blend the original and final images together.
                        If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process.
                        """)

                    gr.Markdown(
                        f"""
                        ### {ui_labels.composite_mask_influence}

                        This parameter controls how much the mask should bias this sensitivity to difference.

                        - **0**: Ignore the mask, only consider differences in image content.
                        - **1**: Follow the mask closely despite image content changes.
                        """)

                    gr.Markdown(
                        f"""
                        ### {ui_labels.composite_difference_threshold}

                        This value represents the difference at which the original pixels will have less than 50% opacity.

                        - **Low values**: Two images patches must be almost the same in order to retain original pixels.
                        - **High values**: Two images patches can be very different and still retain original pixels.
                        """)

                    gr.Markdown(
                        f"""
                        ### {ui_labels.composite_difference_contrast}

                        This value represents the contrast between the opacity of the original and inpainted content.

                        - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting.
                        - **High values**: Ghosting will be less common, but transitions may be very sudden.
                        """)

        self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label),
                                (power, gen_param_labels.mask_blend_power),
                                (scale, gen_param_labels.mask_blend_scale),
                                (detail, gen_param_labels.inpaint_detail_preservation),
                                (mask_inf, gen_param_labels.composite_mask_influence),
                                (dif_thresh, gen_param_labels.composite_difference_threshold),
                                (dif_contr, gen_param_labels.composite_difference_contrast)]

        self.paste_field_names = [field_name for _, field_name in self.infotext_fields]

        return [soft_inpainting_enabled,
                power,
                scale,
                detail,
                mask_inf,
                dif_thresh,
                dif_contr]

    def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
        # Drop masks/overlays left over from a previous generation so that
        # postprocess_maskoverlay cannot apply stale data to this run's images
        # if post_sample returns early.
        self.masks_for_overlay = None
        self.overlay_images = None

        if not enabled:
            return

        if not processing_uses_inpainting(p):
            return

        # Shut off the rounding it normally does.
        p.mask_round = False

        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)

        # p.extra_generation_params["Mask rounding"] = False
        settings.add_generation_params(p.extra_generation_params)

    def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf,
                      dif_thresh, dif_contr):
        if not enabled:
            return

        if not processing_uses_inpainting(p):
            return

        # On the final blend, keep the denoised latent untouched.
        if mba.is_final_blend:
            mba.blended_latent = mba.current_latent
            return

        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)

        # todo: Why is sigma 2D? Both values are the same.
        mba.blended_latent = latent_blend(settings,
                                          mba.init_latent,
                                          mba.current_latent,
                                          get_modified_nmask(settings, mba.nmask, mba.sigma[0]))

    def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf,
                    dif_thresh, dif_contr):
        if not enabled:
            return

        if not processing_uses_inpainting(p):
            return

        nmask = getattr(p, "nmask", None)
        if nmask is None:
            return

        from modules import images
        from modules.shared import opts

        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)

        # since the original code puts holes in the existing overlay images,
        # we have to rebuild them.
        self.overlay_images = []
        for img in p.init_images:

            image = images.flatten(img, opts.img2img_background_color)

            if p.paste_to is None and p.resize_mode != 3:
                image = images.resize_image(p.resize_mode, image, p.width, p.height)

            self.overlay_images.append(image.convert('RGBA'))

        if len(p.init_images) == 1:
            self.overlay_images = self.overlay_images * p.batch_size

        # Decoded samples have no latent difference to analyse; fall back to the plain mask.
        if getattr(ps.samples, 'already_decoded', False):
            self.masks_for_overlay = apply_masks(settings=settings,
                                                 nmask=nmask,
                                                 overlay_images=self.overlay_images,
                                                 width=p.width,
                                                 height=p.height,
                                                 paste_to=p.paste_to)
        else:
            self.masks_for_overlay = apply_adaptive_masks(settings=settings,
                                                          nmask=nmask,
                                                          latent_orig=p.init_latent,
                                                          latent_processed=ps.samples,
                                                          overlay_images=self.overlay_images,
                                                          width=p.width,
                                                          height=p.height,
                                                          paste_to=p.paste_to)

    def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale,
                                detail_preservation, mask_inf, dif_thresh, dif_contr):
        if not enabled:
            return

        if not processing_uses_inpainting(p):
            return

        if self.masks_for_overlay is None:
            return

        if self.overlay_images is None:
            return

        ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index]
        ppmo.overlay_image = self.overlay_images[ppmo.index]