# Path: blob/master/scripts/xyz_grid.py
# 3055 views
from collections import namedtuple
from copy import copy
from itertools import permutations, chain
import random
import csv
import os.path
from io import StringIO
from PIL import Image
import numpy as np

import modules.scripts as scripts
import gradio as gr

from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_schedulers, errors
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, state
import modules.shared as shared
import modules.sd_samplers
import modules.sd_models
import modules.sd_vae
import re

from modules.ui_components import ToolButton

fill_values_symbol = "\U0001f4d2"  # 📒

AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])


def apply_field(field):
    """Return an axis ``apply`` callback that sets attribute ``field`` on ``p`` to the axis value."""
    def fun(p, x, xs):
        setattr(p, field, x)

    return fun


def apply_prompt(p, x, xs):
    """Prompt search/replace: replace the first axis value ``xs[0]`` with ``x`` in both prompts.

    Raises:
        RuntimeError: if ``xs[0]`` occurs in neither the prompt nor the negative prompt.
    """
    if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
        raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")

    p.prompt = p.prompt.replace(xs[0], x)
    p.negative_prompt = p.negative_prompt.replace(xs[0], x)


def apply_order(p, x, xs):
    """Rewrite ``p.prompt`` so the tokens of the permutation ``x`` appear in ``x``'s order.

    The tokens are cut out in order of first appearance. The resulting gaps are then
    refilled with ``x[0]``, ``x[1]``, ... in sequence.

    NOTE(review): this assumes every token occurs in the prompt. A missing token
    (``find() == -1``) corrupts the rebuilt prompt.
    """
    token_order = []

    # Initially grab the tokens from the prompt, so they can be replaced in order of earliest seen
    for token in x:
        token_order.append((p.prompt.find(token), token))

    token_order.sort(key=lambda t: t[0])

    prompt_parts = []

    # Split the prompt up, taking out the tokens
    for _, token in token_order:
        n = p.prompt.find(token)
        prompt_parts.append(p.prompt[0:n])
        p.prompt = p.prompt[n + len(token):]

    # Rebuild the prompt with the tokens in the order we want
    prompt_tmp = ""
    for idx, part in enumerate(prompt_parts):
        prompt_tmp += part
        prompt_tmp += x[idx]
    p.prompt = prompt_tmp + p.prompt


def confirm_samplers(p, xs):
    """Raise RuntimeError for any value whose lower-cased name is not a key of ``sd_samplers.samplers_map``."""
    for x in xs:
        if x.lower() not in sd_samplers.samplers_map:
            raise RuntimeError(f"Unknown sampler: {x}")


def apply_checkpoint(p, x, xs):
    """Override the SD checkpoint for this cell with the closest match for ``x``.

    Raises:
        RuntimeError: if no checkpoint matches ``x``.
    """
    info = modules.sd_models.get_closet_checkpoint_match(x)
    if info is None:
        raise RuntimeError(f"Unknown checkpoint: {x}")
    p.override_settings['sd_model_checkpoint'] = info.name


def confirm_checkpoints(p, xs):
    """Raise RuntimeError if any value has no matching checkpoint."""
    for x in xs:
        if modules.sd_models.get_closet_checkpoint_match(x) is None:
            raise RuntimeError(f"Unknown checkpoint: {x}")


def confirm_checkpoints_or_none(p, xs):
    """Like ``confirm_checkpoints``, but accept empty/"None" values (no checkpoint)."""
    for x in xs:
        if x in (None, "", "None", "none"):
            continue

        if modules.sd_models.get_closet_checkpoint_match(x) is None:
            raise RuntimeError(f"Unknown checkpoint: {x}")


def confirm_range(min_val, max_val, axis_label):
    """Generates a AxisOption.confirm() function that checks all values are within the specified range."""

    def confirm_range_fun(p, xs):
        for x in xs:
            if not (max_val >= x >= min_val):
                raise ValueError(f'{axis_label} value "{x}" out of range [{min_val}, {max_val}]')

    return confirm_range_fun


def apply_size(p, x: str, xs) -> None:
    """Set ``p.width``/``p.height`` from a ``"WIDTHxHEIGHT"`` string.

    Invalid input is reported and skipped (best effort); ``p`` is left unchanged.
    """
    try:
        width, _, height = x.partition('x')
        width = int(width.strip())
        height = int(height.strip())
        p.width = width
        p.height = height
    except ValueError:
        print(f"Invalid size in XYZ plot: {x}")


def find_vae(name: str):
    """Map a user-supplied VAE name to a ``sd_vae.vae_dict`` key, or 'Automatic' / 'None'.

    Matching is case-insensitive and ignores surrounding whitespace. An unknown name
    falls back to 'Automatic' with a console warning.
    """
    name = name.strip().lower()
    if name in ('auto', 'automatic'):
        return 'Automatic'
    if name == 'none':
        return 'None'

    found = next((k for k in modules.sd_vae.vae_dict if k.lower() == name), None)
    if found is None:
        # Warn only when the lookup actually fails. Passing print(...) as next()'s
        # default would evaluate it eagerly and print on every call.
        print(f'No VAE found for {name}; using Automatic')
        return 'Automatic'
    return found


def apply_vae(p, x, xs):
    """Override the VAE for this cell (see ``find_vae`` for name resolution)."""
    p.override_settings['sd_vae'] = find_vae(x)


def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
    """Append the comma-separated style names in ``x`` to ``p.styles``."""
    p.styles.extend(x.split(','))


def apply_uni_pc_order(p, x, xs):
    """Override the UniPC order, capped at ``p.steps - 1``."""
    p.override_settings['uni_pc_order'] = min(x, p.steps - 1)


def apply_face_restore(p, opt, x):
    """Enable/disable face restoration from an axis value.

    Note the parameter names: ``opt`` receives the axis value and ``x`` the list of
    all values (the standard ``apply(p, x, xs)`` order).

    'codeformer' / 'gfpgan' enable restoration with that model. Otherwise
    'true'/'yes'/'y'/'1' enable it and anything else disables it.
    """
    opt = opt.lower()
    if opt == 'codeformer':
        is_active = True
        p.face_restoration_model = 'CodeFormer'
    elif opt == 'gfpgan':
        is_active = True
        p.face_restoration_model = 'GFPGAN'
    else:
        is_active = opt in ('true', 'yes', 'y', '1')

    p.restore_faces = is_active


def apply_override(field, boolean: bool = False):
    """Return an ``apply`` callback that stores the axis value in ``p.override_settings[field]``.

    When ``boolean`` is set, the string value is converted to ``True`` iff it equals
    "true" (case-insensitive).
    """
    def fun(p, x, xs):
        if boolean:
            x = x.lower() == "true"
        p.override_settings[field] = x

    return fun


def boolean_choice(reverse: bool = False):
    """Return a ``choices`` callable yielding ["True", "False"] (or reversed)."""
    def choice():
        return ["False", "True"] if reverse else ["True", "False"]

    return choice


def format_value_add_label(p, opt, x):
    """Format an axis value as ``"<label>: <value>"``, rounding floats to 8 places."""
    if isinstance(x, float):
        x = round(x, 8)

    return f"{opt.label}: {x}"


def format_value(p, opt, x):
    """Return the axis value itself, rounding floats to 8 places."""
    if isinstance(x, float):
        x = round(x, 8)
    return x


def format_value_join_list(p, opt, x):
    """Format a sequence value (e.g. a prompt-order permutation) as a comma-separated string."""
    return ", ".join(x)


def do_nothing(p, x, xs):
    """``apply`` callback for the "Nothing" axis."""
    pass


def format_nothing(p, opt, x):
    """Label formatter for the "Nothing" axis: always empty."""
    return ""


def format_remove_path(p, opt, x):
    """Format a path-like value by keeping only its base name."""
    return os.path.basename(x)


def str_permutations(x):
    """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
    return x


def list_to_csv_string(data_list):
    """Serialise a list of values as a single CSV row (without the trailing newline)."""
    with StringIO() as o:
        csv.writer(o).writerow(data_list)
        return o.getvalue().strip()


def csv_string_to_list_strip(data_str):
    """Parse a CSV string (possibly multi-line) into a flat list of whitespace-stripped fields."""
    return list(map(str.strip, chain.from_iterable(csv.reader(StringIO(data_str), skipinitialspace=True))))


class AxisOption:
    """Description of one selectable plot axis.

    Attributes:
        label: name shown in the UI and matched by ``run``.
        type: converter applied to each parsed value (``int``, ``float``, ``str`` or ``str_permutations``).
        apply: ``apply(p, x, xs)`` callback that applies one value to a processing copy.
        format_value: ``format_value(p, opt, x)`` producing the grid legend text.
        confirm: optional ``confirm(p, xs)`` validator run before processing.
        cost: relative cost of switching values; the costliest axis is iterated outermost.
        prepare: optional hook turning the raw text into a value list.
        choices: optional callable returning the allowed values (enables the dropdown UI).
    """

    def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None, prepare=None):
        self.label = label
        self.type = type
        self.apply = apply
        self.format_value = format_value
        self.confirm = confirm
        self.cost = cost
        self.prepare = prepare
        self.choices = choices


class AxisOptionImg2Img(AxisOption):
    """Axis option offered only in the img2img tab."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_img2img = True


class AxisOptionTxt2Img(AxisOption):
    """Axis option offered only in the txt2img tab."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_img2img = False


axis_options = [
    AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
    AxisOption("Seed", int, apply_field("seed")),
    AxisOption("Var. seed", int, apply_field("subseed")),
    AxisOption("Var. strength", float, apply_field("subseed_strength")),
    AxisOption("Steps", int, apply_field("steps")),
    AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
    AxisOption("CFG Scale", float, apply_field("cfg_scale")),
    AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
    AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
    AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
    AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers if x.name not in opts.hide_samplers]),
    AxisOptionTxt2Img("Hires sampler", str, apply_field("hr_sampler_name"), confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img if x.name not in opts.hide_samplers]),
    AxisOptionImg2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img if x.name not in opts.hide_samplers]),
    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
    AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
    AxisOption("Sigma Churn", float, apply_field("s_churn")),
    AxisOption("Sigma min", float, apply_field("s_tmin")),
    AxisOption("Sigma max", float, apply_field("s_tmax")),
    AxisOption("Sigma noise", float, apply_field("s_noise")),
    AxisOption("Schedule type", str, apply_field("scheduler"), choices=lambda: [x.label for x in sd_schedulers.schedulers]),
    AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
    AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
    AxisOption("Schedule rho", float, apply_override("rho")),
    AxisOption("Beta schedule alpha", float, apply_override("beta_dist_alpha")),
    AxisOption("Beta schedule beta", float, apply_override("beta_dist_beta")),
    AxisOption("Eta", float, apply_field("eta")),
    AxisOption("Clip skip", int, apply_override('CLIP_stop_at_last_layers')),
    AxisOption("Denoising", float, apply_field("denoising_strength")),
    AxisOption("Initial noise multiplier", float, apply_field("initial_noise_multiplier")),
    AxisOption("Extra noise", float, apply_override("img2img_extra_noise")),
    AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
    AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
    AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['Automatic', 'None'] + list(sd_vae.vae_dict)),
    AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
    AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
    AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
    AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),
    AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),
    AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)),
    AxisOption("SGM noise multiplier", str, apply_override('sgm_noise_multiplier', boolean=True), choices=boolean_choice(reverse=True)),
    AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
    AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
    AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
    AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]),
    AxisOption("Size", str, apply_size),
]


def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):
    """Run ``cell`` for every (x, y, z) combination and assemble the results into grids.

    Cells are visited with ``first_axes_processed`` as the outermost loop and
    ``second_axes_processed`` as the middle loop.

    The returned ``Processed`` holds, in ``images``:
      * the combined Z grid at index 0;
      * one X/Y sub-grid per Z value;
      * every individual cell image.

    ``infotexts`` mirrors that layout. ``all_prompts``/``all_seeds`` intentionally lack
    the leading Z-grid entry (see the TODO below). A failed cell is filled with a blank
    image.

    ``include_lone_images`` and ``include_sub_grids`` are accepted but not used here;
    the caller trims the result.

    Returns ``Processed(p, [])`` if no cell produced an image.
    """
    hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
    ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
    title_texts = [[images.GridAnnotation(z)] for z in z_labels]

    list_size = (len(xs) * len(ys) * len(zs))

    processed_result = None

    state.job_count = list_size * p.n_iter

    def process_cell(x, y, z, ix, iy, iz):
        nonlocal processed_result

        def index(ix, iy, iz):
            # Flat index in x-fastest, then y, then z order.
            return ix + iy * len(xs) + iz * len(xs) * len(ys)

        state.job = f"{index(ix, iy, iz) + 1} out of {list_size}"

        processed: Processed = cell(x, y, z, ix, iy, iz)

        if processed_result is None:
            # Use our first processed result object as a template container to hold our full results
            processed_result = copy(processed)
            processed_result.images = [None] * list_size
            processed_result.all_prompts = [None] * list_size
            processed_result.all_seeds = [None] * list_size
            processed_result.infotexts = [None] * list_size
            processed_result.index_of_first_image = 1

        idx = index(ix, iy, iz)
        if processed.images:
            # Non-empty list indicates some degree of success.
            processed_result.images[idx] = processed.images[0]
            processed_result.all_prompts[idx] = processed.prompt
            processed_result.all_seeds[idx] = processed.seed
            processed_result.infotexts[idx] = processed.infotexts[0]
        else:
            # Failed or interrupted cell: keep the grid rectangular with a blank image.
            cell_mode = "P"
            cell_size = (processed_result.width, processed_result.height)
            if processed_result.images[0] is not None:
                cell_mode = processed_result.images[0].mode
                # This corrects size in case of batches:
                cell_size = processed_result.images[0].size
            processed_result.images[idx] = Image.new(cell_mode, cell_size)

    if first_axes_processed == 'x':
        for ix, x in enumerate(xs):
            if second_axes_processed == 'y':
                for iy, y in enumerate(ys):
                    for iz, z in enumerate(zs):
                        process_cell(x, y, z, ix, iy, iz)
            else:
                for iz, z in enumerate(zs):
                    for iy, y in enumerate(ys):
                        process_cell(x, y, z, ix, iy, iz)
    elif first_axes_processed == 'y':
        for iy, y in enumerate(ys):
            if second_axes_processed == 'x':
                for ix, x in enumerate(xs):
                    for iz, z in enumerate(zs):
                        process_cell(x, y, z, ix, iy, iz)
            else:
                for iz, z in enumerate(zs):
                    for ix, x in enumerate(xs):
                        process_cell(x, y, z, ix, iy, iz)
    elif first_axes_processed == 'z':
        for iz, z in enumerate(zs):
            if second_axes_processed == 'x':
                for ix, x in enumerate(xs):
                    for iy, y in enumerate(ys):
                        process_cell(x, y, z, ix, iy, iz)
            else:
                for iy, y in enumerate(ys):
                    for ix, x in enumerate(xs):
                        process_cell(x, y, z, ix, iy, iz)

    if not processed_result:
        # Should never happen, I've only seen it on one of four open tabs and it needed to refresh.
        print("Unexpected error: Processing could not begin, you may need to refresh the tab or restart the service.")
        return Processed(p, [])
    elif not any(processed_result.images):
        print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
        return Processed(p, [])

    z_count = len(zs)

    for i in range(z_count):
        # `+ i` accounts for the i sub-grids already inserted at the front of the list.
        start_index = (i * len(xs) * len(ys)) + i
        end_index = start_index + len(xs) * len(ys)
        grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys))
        if draw_legend:
            grid_max_w, grid_max_h = map(max, zip(*(img.size for img in processed_result.images[start_index:end_index])))
            grid = images.draw_grid_annotations(grid, grid_max_w, grid_max_h, hor_texts, ver_texts, margin_size)
        processed_result.images.insert(i, grid)
        processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index])
        processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index])
        processed_result.infotexts.insert(i, processed_result.infotexts[start_index])

    z_grid = images.image_grid(processed_result.images[:z_count], rows=1)
    z_sub_grid_max_w, z_sub_grid_max_h = map(max, zip(*(img.size for img in processed_result.images[:z_count])))
    if draw_legend:
        z_grid = images.draw_grid_annotations(z_grid, z_sub_grid_max_w, z_sub_grid_max_h, title_texts, [[images.GridAnnotation()]])
    processed_result.images.insert(0, z_grid)
    # TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
    # processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
    # processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
    processed_result.infotexts.insert(0, processed_result.infotexts[0])

    return processed_result


class SharedSettingsStackHelper(object):
    """Context manager whose exit reloads the SD model weights and the VAE weights.

    Presumably this restores the configured model/VAE after cells overrode them
    (e.g. via the checkpoint/VAE axes). TODO confirm against ``sd_models``/``sd_vae``.
    """

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, tb):
        modules.sd_models.reload_model_weights()
        modules.sd_vae.reload_vae_weights()


re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")

# A non-negative float literal for the float range patterns: digits, an optional
# fractional part and an optional exponent. The decimal point must be escaped. An
# unescaped "." matches any character, so a plain value like "1e-4" was mis-parsed
# as the range "1e" - "4" and then crashed in float().
_re_float_literal = r"\d+(?:\.\d*)?(?:[eE][+-]?\d+)?"
re_range_float = re.compile(rf"\s*([+-]?\s*{_re_float_literal})\s*-\s*([+-]?\s*{_re_float_literal})(?:\s*\(([+-]{_re_float_literal})\s*\))?\s*")

re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*])?\s*")
re_range_count_float = re.compile(rf"\s*([+-]?\s*{_re_float_literal})\s*-\s*([+-]?\s*{_re_float_literal})(?:\s*\[(\d+(?:\.\d*)?)\s*])?\s*")


class Script(scripts.Script):
    """The "X/Y/Z plot" script: render a grid of images varying up to three parameters."""

    def title(self):
        return "X/Y/Z plot"

    def ui(self, is_img2img):
        """Build the Gradio controls and return them in the order ``run`` expects."""
        # Exact type check on purpose: plain AxisOption is shared by both tabs, while the
        # subclasses carry an ``is_img2img`` flag that must match the current tab.
        self.current_axis_options = [x for x in axis_options if type(x) is AxisOption or x.is_img2img == is_img2img]

        with gr.Row():
            with gr.Column(scale=19):
                with gr.Row():
                    x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
                    x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
                    x_values_dropdown = gr.Dropdown(label="X values", visible=False, multiselect=True, interactive=True)
                    fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)

                with gr.Row():
                    y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
                    y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
                    y_values_dropdown = gr.Dropdown(label="Y values", visible=False, multiselect=True, interactive=True)
                    fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)

                with gr.Row():
                    z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
                    z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
                    z_values_dropdown = gr.Dropdown(label="Z values", visible=False, multiselect=True, interactive=True)
                    fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)

        with gr.Row(variant="compact", elem_id="axis_options"):
            with gr.Column():
                draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
                no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
                with gr.Row():
                    vary_seeds_x = gr.Checkbox(label='Vary seeds for X', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_x"), tooltip="Use different seeds for images along X axis.")
                    vary_seeds_y = gr.Checkbox(label='Vary seeds for Y', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_y"), tooltip="Use different seeds for images along Y axis.")
                    vary_seeds_z = gr.Checkbox(label='Vary seeds for Z', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_z"), tooltip="Use different seeds for images along Z axis.")
            with gr.Column():
                include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
                include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
                csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode"))
            with gr.Column():
                margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))

        with gr.Row(variant="compact", elem_id="swap_axes"):
            swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
            swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
            swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")

        def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown):
            return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown

        xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown]
        swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
        yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown]
        swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
        xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
        swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)

        def fill(axis_type, csv_mode):
            # Fill either the textbox (CSV mode) or the dropdown with all allowed values.
            axis = self.current_axis_options[axis_type]
            if axis.choices:
                if csv_mode:
                    return list_to_csv_string(axis.choices()), gr.update()
                else:
                    return gr.update(), axis.choices()
            else:
                return gr.update(), gr.update()

        fill_x_button.click(fn=fill, inputs=[x_type, csv_mode], outputs=[x_values, x_values_dropdown])
        fill_y_button.click(fn=fill, inputs=[y_type, csv_mode], outputs=[y_values, y_values_dropdown])
        fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown])

        def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode):
            # Show the dropdown or the textbox for the selected axis type. Values are moved
            # between the two widgets (keeping only allowed choices) when the mode changes.
            axis_type = axis_type or 0  # if axle type is None set to 0

            choices = self.current_axis_options[axis_type].choices
            has_choices = choices is not None

            if has_choices:
                choices = choices()
                if csv_mode:
                    if axis_values_dropdown:
                        axis_values = list_to_csv_string(list(filter(lambda x: x in choices, axis_values_dropdown)))
                        axis_values_dropdown = []
                else:
                    if axis_values:
                        axis_values_dropdown = list(filter(lambda x: x in choices, csv_string_to_list_strip(axis_values)))
                        axis_values = ""

            return (gr.Button.update(visible=has_choices), gr.Textbox.update(visible=not has_choices or csv_mode, value=axis_values),
                    gr.update(choices=choices if has_choices else None, visible=has_choices and not csv_mode, value=axis_values_dropdown))

        x_type.change(fn=select_axis, inputs=[x_type, x_values, x_values_dropdown, csv_mode], outputs=[fill_x_button, x_values, x_values_dropdown])
        y_type.change(fn=select_axis, inputs=[y_type, y_values, y_values_dropdown, csv_mode], outputs=[fill_y_button, y_values, y_values_dropdown])
        z_type.change(fn=select_axis, inputs=[z_type, z_values, z_values_dropdown, csv_mode], outputs=[fill_z_button, z_values, z_values_dropdown])

        def change_choice_mode(csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown):
            _fill_x_button, _x_values, _x_values_dropdown = select_axis(x_type, x_values, x_values_dropdown, csv_mode)
            _fill_y_button, _y_values, _y_values_dropdown = select_axis(y_type, y_values, y_values_dropdown, csv_mode)
            _fill_z_button, _z_values, _z_values_dropdown = select_axis(z_type, z_values, z_values_dropdown, csv_mode)
            return _fill_x_button, _x_values, _x_values_dropdown, _fill_y_button, _y_values, _y_values_dropdown, _fill_z_button, _z_values, _z_values_dropdown

        csv_mode.change(fn=change_choice_mode, inputs=[csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown], outputs=[fill_x_button, x_values, x_values_dropdown, fill_y_button, y_values, y_values_dropdown, fill_z_button, z_values, z_values_dropdown])

        def get_dropdown_update_from_params(axis, params):
            # Restore a dropdown's selection from pasted infotext ("X Values" etc.).
            val_key = f"{axis} Values"
            vals = params.get(val_key, "")
            valslist = csv_string_to_list_strip(vals)
            return gr.update(value=valslist)

        self.infotext_fields = (
            (x_type, "X Type"),
            (x_values, "X Values"),
            (x_values_dropdown, lambda params: get_dropdown_update_from_params("X", params)),
            (y_type, "Y Type"),
            (y_values, "Y Values"),
            (y_values_dropdown, lambda params: get_dropdown_update_from_params("Y", params)),
            (z_type, "Z Type"),
            (z_values, "Z Values"),
            (z_values_dropdown, lambda params: get_dropdown_update_from_params("Z", params)),
        )

        return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode]

    def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode):
        """Parse the three axes, render every cell and return the assembled grid(s).

        Axis text supports ranges. For int axes:
          * ``a-b`` / ``a-b (+s)`` give an inclusive stepped range;
          * ``a-b [n]`` gives n evenly spaced values.
        Float axes accept the same forms.
        """
        x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0  # if axle type is None set to 0

        if not no_fixed_seeds:
            modules.processing.fix_seed(p)

        if not opts.return_grid:
            p.batch_size = 1

        def process_axis(opt, vals, vals_dropdown):
            # Turn the raw axis input into a validated list of typed values.
            if opt.label == 'Nothing':
                return [0]

            if opt.choices is not None and not csv_mode:
                valslist = vals_dropdown
            elif opt.prepare is not None:
                valslist = opt.prepare(vals)
            else:
                valslist = csv_string_to_list_strip(vals)

            if opt.type == int:
                valslist_ext = []

                for val in valslist:
                    if val.strip() == '':
                        continue
                    m = re_range.fullmatch(val)
                    mc = re_range_count.fullmatch(val)
                    if m is not None:
                        start = int(m.group(1))
                        end = int(m.group(2))
                        step = int(m.group(3)) if m.group(3) is not None else 1

                        # The end value is inclusive. Extend the stop one unit in the direction
                        # of travel so that negative steps (e.g. "10-1 (-3)") also reach it.
                        # Previously the stop was always end + 1, which dropped the final value.
                        valslist_ext += list(range(start, end + (1 if step > 0 else -1), step))
                    elif mc is not None:
                        start = int(mc.group(1))
                        end = int(mc.group(2))
                        num = int(mc.group(3)) if mc.group(3) is not None else 1

                        valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
                    else:
                        valslist_ext.append(val)

                valslist = valslist_ext
            elif opt.type == float:
                valslist_ext = []

                for val in valslist:
                    if val.strip() == '':
                        continue
                    m = re_range_float.fullmatch(val)
                    mc = re_range_count_float.fullmatch(val)
                    if m is not None:
                        start = float(m.group(1))
                        end = float(m.group(2))
                        step = float(m.group(3)) if m.group(3) is not None else 1

                        valslist_ext += np.arange(start, end + step, step).tolist()
                    elif mc is not None:
                        start = float(mc.group(1))
                        end = float(mc.group(2))
                        num = int(mc.group(3)) if mc.group(3) is not None else 1

                        valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
                    else:
                        valslist_ext.append(val)

                valslist = valslist_ext
            elif opt.type == str_permutations:
                valslist = list(permutations(valslist))

            valslist = [opt.type(x) for x in valslist]

            # Confirm options are valid before starting
            if opt.confirm:
                opt.confirm(p, valslist)

            return valslist

        x_opt = self.current_axis_options[x_type]
        if x_opt.choices is not None and not csv_mode:
            x_values = list_to_csv_string(x_values_dropdown)
        xs = process_axis(x_opt, x_values, x_values_dropdown)

        y_opt = self.current_axis_options[y_type]
        if y_opt.choices is not None and not csv_mode:
            y_values = list_to_csv_string(y_values_dropdown)
        ys = process_axis(y_opt, y_values, y_values_dropdown)

        z_opt = self.current_axis_options[z_type]
        if z_opt.choices is not None and not csv_mode:
            z_values = list_to_csv_string(z_values_dropdown)
        zs = process_axis(z_opt, z_values, z_values_dropdown)

        # this could be moved to common code, but unlikely to be ever triggered anywhere else
        Image.MAX_IMAGE_PIXELS = None  # disable check in Pillow and rely on check below to allow large custom image sizes
        grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
        assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'

        def fix_axis_seeds(axis_opt, axis_list):
            # Replace "random" seeds (None, '' or -1) with concrete random values so that
            # every cell along the axis uses the same, reproducible seed.
            if axis_opt.label in ['Seed', 'Var. seed']:
                return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
            else:
                return axis_list

        if not no_fixed_seeds:
            xs = fix_axis_seeds(x_opt, xs)
            ys = fix_axis_seeds(y_opt, ys)
            zs = fix_axis_seeds(z_opt, zs)

        # Estimate the total number of sampling steps for the progress bar.
        if x_opt.label == 'Steps':
            total_steps = sum(xs) * len(ys) * len(zs)
        elif y_opt.label == 'Steps':
            total_steps = sum(ys) * len(xs) * len(zs)
        elif z_opt.label == 'Steps':
            total_steps = sum(zs) * len(xs) * len(ys)
        else:
            total_steps = p.steps * len(xs) * len(ys) * len(zs)

        if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
            if x_opt.label == "Hires steps":
                total_steps += sum(xs) * len(ys) * len(zs)
            elif y_opt.label == "Hires steps":
                total_steps += sum(ys) * len(xs) * len(zs)
            elif z_opt.label == "Hires steps":
                total_steps += sum(zs) * len(xs) * len(ys)
            elif p.hr_second_pass_steps:
                total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
            else:
                total_steps *= 2

        total_steps *= p.n_iter

        image_cell_count = p.n_iter * p.batch_size
        cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
        plural_s = 's' if len(zs) > 1 else ''
        print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
        shared.total_tqdm.updateTotal(total_steps)

        state.xyz_plot_x = AxisInfo(x_opt, xs)
        state.xyz_plot_y = AxisInfo(y_opt, ys)
        state.xyz_plot_z = AxisInfo(z_opt, zs)

        # If one of the axes is very slow to change between (like SD model
        # checkpoint), then make sure it is in the outer iteration of the nested
        # `for` loop.
        first_axes_processed = 'z'
        second_axes_processed = 'y'
        if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
            first_axes_processed = 'x'
            if y_opt.cost > z_opt.cost:
                second_axes_processed = 'y'
            else:
                second_axes_processed = 'z'
        elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost:
            first_axes_processed = 'y'
            if x_opt.cost > z_opt.cost:
                second_axes_processed = 'x'
            else:
                second_axes_processed = 'z'
        elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost:
            first_axes_processed = 'z'
            if x_opt.cost > y_opt.cost:
                second_axes_processed = 'x'
            else:
                second_axes_processed = 'y'

        # Slot 0: main grid infotext; slot 1 + iz: infotext of the sub-grid for z index iz.
        grid_infotext = [None] * (1 + len(zs))

        def cell(x, y, z, ix, iy, iz):
            # Render a single cell on a copy of ``p`` with the three axis values applied.
            if shared.state.interrupted or state.stopping_generation:
                return Processed(p, [], p.seed, "")

            pc = copy(p)
            pc.styles = pc.styles[:]
            x_opt.apply(pc, x, xs)
            y_opt.apply(pc, y, ys)
            z_opt.apply(pc, z, zs)

            xdim = len(xs) if vary_seeds_x else 1
            ydim = len(ys) if vary_seeds_y else 1

            if vary_seeds_x:
                pc.seed += ix
            if vary_seeds_y:
                pc.seed += iy * xdim
            if vary_seeds_z:
                pc.seed += iz * xdim * ydim

            try:
                res = process_images(pc)
            except Exception as e:
                errors.display(e, "generating image for xyz plot")

                res = Processed(p, [], p.seed, "")

            # Sets subgrid infotexts
            subgrid_index = 1 + iz
            if grid_infotext[subgrid_index] is None and ix == 0 and iy == 0:
                pc.extra_generation_params = copy(pc.extra_generation_params)
                pc.extra_generation_params['Script'] = self.title()

                if x_opt.label != 'Nothing':
                    pc.extra_generation_params["X Type"] = x_opt.label
                    pc.extra_generation_params["X Values"] = x_values
                    if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])

                if y_opt.label != 'Nothing':
                    pc.extra_generation_params["Y Type"] = y_opt.label
                    pc.extra_generation_params["Y Values"] = y_values
                    if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])

                grid_infotext[subgrid_index] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)

            # Sets main grid infotext
            if grid_infotext[0] is None and ix == 0 and iy == 0 and iz == 0:
                pc.extra_generation_params = copy(pc.extra_generation_params)

                if z_opt.label != 'Nothing':
                    pc.extra_generation_params["Z Type"] = z_opt.label
                    pc.extra_generation_params["Z Values"] = z_values
                    if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs])

                grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)

            return res

        with SharedSettingsStackHelper():
            processed = draw_xyz_grid(
                p,
                xs=xs,
                ys=ys,
                zs=zs,
                x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
                y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
                z_labels=[z_opt.format_value(p, z_opt, z) for z in zs],
                cell=cell,
                draw_legend=draw_legend,
                include_lone_images=include_lone_images,
                include_sub_grids=include_sub_grids,
                first_axes_processed=first_axes_processed,
                second_axes_processed=second_axes_processed,
                margin_size=margin_size
            )

        if not processed.images:
            # It broke, no further handling needed.
            return processed

        z_count = len(zs)

        # Set the grid infotexts to the real ones with extra_generation_params (1 main grid + z_count sub-grids)
        processed.infotexts[:1 + z_count] = grid_infotext[:1 + z_count]

        if not include_lone_images:
            # Don't need sub-images anymore, drop from list:
            processed.images = processed.images[:z_count + 1]

        if opts.grid_save:
            # Auto-save main and sub-grids:
            grid_count = z_count + 1 if z_count > 1 else 1
            for g in range(grid_count):
                # TODO: See previous comment about intentional data misalignment.
                adj_g = g - 1 if g > 0 else g
                images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
                if not include_sub_grids:  # if not include_sub_grids then skip saving after the first grid
                    break

        if not include_sub_grids:
            # Done with sub-grids, drop all related information:
            for _ in range(z_count):
                del processed.images[1]
                del processed.all_prompts[1]
                del processed.all_seeds[1]
                del processed.infotexts[1]

        return processed