# Path: blob/master/modules/api/api.py
# 2447 views
import base64
import io
import os
import time
import datetime
import uvicorn
import ipaddress
import requests
import gradio as gr
from threading import Lock
from io import BytesIO
from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from secrets import compare_digest

import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import Any
import piexif
import piexif.helper
from contextlib import closing
from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task
# `from modules.progress import current_task` binds the value present at import time; if
# modules.progress later rebinds that global, the imported name does not follow. The module
# object is imported too so the live value can be read when a request arrives.
from modules import progress as task_progress


def script_name_to_index(name, scripts):
    """Return the index of the script whose title matches `name` (case-insensitive) in `scripts`.

    Raises:
        HTTPException: 422 if no script in `scripts` has that title.
    """
    try:
        return [script.title().lower() for script in scripts].index(name.lower())
    except Exception as e:
        raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e


def validate_sampler_name(name):
    """Return `name` unchanged if it is a key of `sd_samplers.all_samplers_map`.

    Raises:
        HTTPException: 400 if the sampler is unknown.
    """
    config = sd_samplers.all_samplers_map.get(name, None)
    if config is None:
        raise HTTPException(status_code=400, detail="Sampler not found")

    return name


def setUpscalers(req):
    """Return the fields of request model `req` as a new dict, ready to pass to run_extras.

    `upscaler_1`/`upscaler_2` are renamed to `extras_upscaler_1`/`extras_upscaler_2`.
    A copy of `vars(req)` is used so the request model's own attribute dict is not mutated.
    """
    reqDict = dict(vars(req))
    reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
    reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
    return reqDict


def verify_url(url):
    """Returns True if the url refers to a global resource.

    Every IPv4 address the URL's host name resolves to must be globally routable. Any parsing or
    resolution error (including a URL with no host) is treated as "not global".
    """

    import socket
    from urllib.parse import urlparse
    try:
        parsed_url = urlparse(url)
        # .hostname drops any "user:pass@" prefix and ":port" suffix that .netloc keeps; with
        # .netloc those made gethostbyname_ex fail and rejected otherwise valid global URLs.
        domain_name = parsed_url.hostname
        host = socket.gethostbyname_ex(domain_name)
        for ip in host[2]:
            ip_addr = ipaddress.ip_address(ip)
            if not ip_addr.is_global:
                return False
    except Exception:
        return False

    return True


def decode_base64_to_image(encoding):
    """Decode an image given as an http(s) URL, a `data:image/...` URL, or a raw base64 string.

    URL fetching is controlled by `opts.api_enable_requests`. When `opts.api_forbid_local_requests`
    is set, the URL must pass `verify_url`.

    Raises:
        HTTPException: 500 when fetching is disabled or forbidden, or when the data cannot be decoded.
    """
    if encoding.startswith(("http://", "https://")):
        if not opts.api_enable_requests:
            raise HTTPException(status_code=500, detail="Requests not allowed")

        if opts.api_forbid_local_requests and not verify_url(encoding):
            raise HTTPException(status_code=500, detail="Request to local resource not allowed")

        headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
        # NOTE(review): requests follows redirects by default, so only the initial URL is checked by
        # verify_url above; a global URL redirecting to a local address is not blocked (SSRF).
        response = requests.get(encoding, timeout=30, headers=headers)
        try:
            image = images.read(BytesIO(response.content))
            return image
        except Exception as e:
            raise HTTPException(status_code=500, detail="Invalid image url") from e

    # The data-URL split is inside the try so a malformed "data:image/..." string yields the same
    # "Invalid encoded image" error instead of an uncaught IndexError.
    try:
        if encoding.startswith("data:image/"):
            # "data:image/png;base64,<payload>" -> "<payload>"
            encoding = encoding.split(";")[1].split(",")[1]
        image = images.read(BytesIO(base64.b64decode(encoding)))
        return image
    except Exception as e:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from e


def encode_pil_to_base64(image):
    """Encode a PIL image as base64 bytes in the format named by `opts.samples_format`.

    A `str` argument is returned unchanged. For PNG, string key/value pairs from `image.info` are
    written as text chunks. For JPEG/WEBP, RGBA/P images are converted to RGB, and
    `image.info['parameters']` is stored in the EXIF UserComment.

    Raises:
        HTTPException: 500 for any other `samples_format`.
    """
    if isinstance(image, str):
        return image

    samples_format = opts.samples_format.lower()
    with io.BytesIO() as output_bytes:
        if samples_format == 'png':
            use_metadata = False
            metadata = PngImagePlugin.PngInfo()
            for key, value in image.info.items():
                if isinstance(key, str) and isinstance(value, str):
                    metadata.add_text(key, value)
                    use_metadata = True
            image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)

        elif samples_format in ("jpg", "jpeg", "webp"):
            if image.mode in ("RGBA", "P"):
                image = image.convert("RGB")
            parameters = image.info.get('parameters', None)
            exif_bytes = piexif.dump({
                "Exif": {piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode")}
            })
            if samples_format in ("jpg", "jpeg"):
                image.save(output_bytes, format="JPEG", exif=exif_bytes, quality=opts.jpeg_quality)
            else:
                image.save(output_bytes, format="WEBP", exif=exif_bytes, quality=opts.jpeg_quality)

        else:
            raise HTTPException(status_code=500, detail="Invalid image format")

        bytes_data = output_bytes.getvalue()

    return base64.b64encode(bytes_data)


def api_middleware(app: FastAPI):
    """Install timing/logging middleware and JSON exception handlers on `app`.

    Every response gets an `X-Process-Time` header. Requests under `/sdapi` are printed when
    `shared.cmd_opts.api_log` is set. Exceptions become JSON responses built by `handle_exception`.
    If the WEBUI_RICH_EXCEPTIONS environment variable is set and `rich` can be imported,
    non-HTTPException errors are printed with rich tracebacks.
    """
    rich_available = False
    try:
        if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
            import anyio  # importing just so it can be placed on silent list
            import starlette  # importing just so it can be placed on silent list
            from rich.console import Console
            console = Console()
            rich_available = True
    except Exception:
        pass

    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        ts = time.time()
        res: Response = await call_next(req)
        duration = str(round(time.time() - ts, 4))
        res.headers["X-Process-Time"] = duration
        endpoint = req.scope.get('path', 'err')
        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
                t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                code=res.status_code,
                ver=req.scope.get('http_version', '0.0'),
                # the ASGI scope may carry client=None, so fall back on a falsy value, not only a missing key
                cli=(req.scope.get('client') or ('0:0.0.0', 0))[0],
                prot=req.scope.get('scheme', 'err'),
                method=req.scope.get('method', 'err'),
                endpoint=endpoint,
                duration=duration,
            ))
        return res

    def handle_exception(request: Request, e: Exception):
        """Build a JSONResponse describing `e`; status comes from `e.status_code`, default 500."""
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get('detail', ''),
            "body": vars(e).get('body', ''),
            "errors": str(e),
        }
        if not isinstance(e, HTTPException):  # do not print backtrace on known httpexceptions
            message = f"API error: {request.method}: {request.url} {err}"
            if rich_available:
                print(message)
                console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
            else:
                errors.report(message, exc_info=True)
        return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)


class Api:
    """Registers and implements the `/sdapi/v1/*` HTTP endpoints on a FastAPI app."""

    def __init__(self, app: FastAPI, queue_lock: Lock):
        """Register all routes on `app`, then compute default script args for txt2img and img2img.

        Args:
            app: the FastAPI application to attach routes and middleware to.
            queue_lock: lock held around generation and other model-touching work.
        """
        # Always defined, so `auth` never hits a missing attribute; it is only filled (and `auth`
        # only attached to routes) when --api-auth credentials are configured.
        self.credentials = {}
        if shared.cmd_opts.api_auth:
            for auth in shared.cmd_opts.api_auth.split(","):
                # split on the first ":" only so that passwords may themselves contain ":"
                user, password = auth.split(":", 1)
                self.credentials[user] = password

        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        api_middleware(self.app)
        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
        self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
        self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
        self.add_api_route("/sdapi/v1/schedulers", self.get_schedulers, methods=["GET"], response_model=list[models.SchedulerItem])
        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
        self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
        self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
        self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])

        if shared.cmd_opts.api_server_stop:
            self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
            self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
            self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])

        self.default_script_arg_txt2img = []
        self.default_script_arg_img2img = []

        txt2img_script_runner = scripts.scripts_txt2img
        img2img_script_runner = scripts.scripts_img2img

        if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
            ui.create_ui()

        if not txt2img_script_runner.scripts:
            txt2img_script_runner.initialize_scripts(False)
        if not self.default_script_arg_txt2img:
            self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)

        if not img2img_script_runner.scripts:
            img2img_script_runner.initialize_scripts(True)
        if not self.default_script_arg_img2img:
            self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Add a route to the app, requiring HTTP Basic auth when --api-auth is configured."""
        if shared.cmd_opts.api_auth:
            return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
        return self.app.add_api_route(path, endpoint, **kwargs)

    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
        """FastAPI dependency validating HTTP Basic credentials against `self.credentials`.

        Raises:
            HTTPException: 401 with a `WWW-Authenticate: Basic` header on unknown user or wrong password.
        """
        expected = self.credentials.get(credentials.username)
        # compare_digest accepts str only when it is ASCII-only; comparing UTF-8 bytes makes a
        # non-ASCII password get a normal 401/success instead of an internal TypeError.
        if expected is not None and compare_digest(credentials.password.encode("utf-8"), expected.encode("utf-8")):
            return True

        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})

    def get_selectable_script(self, script_name, script_runner):
        """Return `(script, index)` for a selectable script by title, or `(None, None)` if no name is given."""
        if script_name is None or script_name == "":
            return None, None

        script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
        script = script_runner.selectable_scripts[script_idx]
        return script, script_idx

    def get_scripts_list(self):
        """Return the names of all named txt2img and img2img scripts."""
        t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
        i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]

        return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)

    def get_script_info(self):
        """Return the `api_info` of every txt2img and img2img script that provides one."""
        res = []

        for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
            res += [script.api_info for script in script_list if script.api_info is not None]

        return res

    def get_script(self, script_name, script_runner):
        """Return the script titled `script_name` from `script_runner.scripts`, or None if no name is given.

        Raises:
            HTTPException: 422 if the name does not match any script.
        """
        if script_name is None or script_name == "":
            # a single None (not a tuple), since callers test the result with `is None`
            return None

        script_idx = script_name_to_index(script_name, script_runner.scripts)
        return script_runner.scripts[script_idx]

    def init_default_script_args(self, script_runner):
        """Build the default flat script-argument list for `script_runner`.

        Position 0 is 0 (no selectable script). Each script's slice `[args_from:args_to]` holds the
        `.value` of the components returned by its `ui()`. Positions no script fills stay None.
        """
        # find max idx from the scripts in runner and generate a none array to init script_args
        last_arg_index = 1
        for script in script_runner.scripts:
            if last_arg_index < script.args_to:
                last_arg_index = script.args_to
        # None everywhere except position 0 to initialize script args
        script_args = [None] * last_arg_index
        script_args[0] = 0

        # get default values
        with gr.Blocks():  # will throw errors calling ui function without this
            for script in script_runner.scripts:
                # build the UI once and reuse the result (it was previously built twice per script)
                ui_components = script.ui(script.is_img2img)
                if ui_components:
                    script_args[script.args_from:script.args_to] = [elem.value for elem in ui_components]
        return script_args

    def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
        """Assemble the flat script-argument list for one generation request.

        Starts from a copy of `default_script_args`, then applies, in order:
        1. `input_script_args` (index -> value, e.g. values read from infotext);
        2. `request.script_args` for the selected script;
        3. `args` of each entry in `request.alwayson_scripts`.

        Raises:
            HTTPException: 422 for an unknown always-on script name, or a selectable script listed
                as always-on.
        """
        script_args = default_script_args.copy()

        if input_script_args is not None:
            for index, value in input_script_args.items():
                script_args[index] = value

        # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
        if selectable_scripts:
            # Overwrite only the script's own slots. Assigning a list of a different length to the
            # whole slice would resize script_args and shift every later script's arguments;
            # positions not supplied by the request keep their defaults.
            selectable_args = request.script_args or []
            arg_count = min(selectable_scripts.args_to - selectable_scripts.args_from, len(selectable_args))
            script_args[selectable_scripts.args_from:selectable_scripts.args_from + arg_count] = selectable_args[:arg_count]
            script_args[0] = selectable_idx + 1

        # Now check for always on scripts
        if request.alwayson_scripts:
            for alwayson_script_name in request.alwayson_scripts.keys():
                alwayson_script = self.get_script(alwayson_script_name, script_runner)
                if alwayson_script is None:
                    raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
                # Selectable script in always on script param check
                if alwayson_script.alwayson is False:
                    raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
                # always on script with no arg should always run so you don't really need to add them to the requests
                if "args" in request.alwayson_scripts[alwayson_script_name]:
                    # min between arg length in scriptrunner and arg length in the request
                    for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
                        script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
        return script_args

    def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
        """Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext.

        If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.

        Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.

        Returns the parsed infotext parameters, or `{}` when the request has no infotext.
        """

        if not request.infotext:
            return {}

        possible_fields = infotext_utils.paste_fields[tabname]["fields"]
        # pydantic v2 renamed .dict() to .model_dump(); detect v2 by the method itself
        set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "model_dump") else request.dict(exclude_unset=True)
        params = infotext_utils.parse_generation_parameters(request.infotext)

        def get_field_value(field, params):
            value = field.function(params) if field.function else params.get(field.label)
            if value is None:
                return None

            if field.api in request.__fields__:
                target_type = request.__fields__[field.api].type_
            else:
                target_type = type(field.component.value)

            if target_type == type(None):
                return None

            if isinstance(value, dict) and value.get('__type__') == 'generic_update':  # this is a gradio.update rather than a value
                value = value.get('value')

            if value is not None and not isinstance(value, target_type):
                value = target_type(value)

            return value

        for field in possible_fields:
            if not field.api:
                continue

            if field.api in set_fields:
                continue

            value = get_field_value(field, params)
            if value is not None:
                setattr(request, field.api, value)

        if request.override_settings is None:
            request.override_settings = {}

        overridden_settings = infotext_utils.get_override_settings(params)
        for _, setting_name, value in overridden_settings:
            if setting_name not in request.override_settings:
                request.override_settings[setting_name] = value

        if script_runner is not None and mentioned_script_args is not None:
            indexes = {v: i for i, v in enumerate(script_runner.inputs)}
            script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)

            for field, index in script_fields:
                value = get_field_value(field, params)

                if value is None:
                    continue

                mentioned_script_args[index] = value

        return params

    def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
        """POST /sdapi/v1/txt2img: run txt2img (optionally through a selectable script) under the queue lock.

        Returns base64 images (unless `send_images` is false), the request parameters, and the
        processing info JSON.
        """
        task_id = txt2imgreq.force_task_id or create_task_id("txt2img")

        script_runner = scripts.scripts_txt2img

        infotext_script_args = {}
        self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)

        selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
        sampler, scheduler = sd_samplers.get_sampler_and_scheduler(txt2imgreq.sampler_name or txt2imgreq.sampler_index, txt2imgreq.scheduler)

        populate = txt2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(sampler),
            "do_not_save_samples": not txt2imgreq.save_images,
            "do_not_save_grid": not txt2imgreq.save_images,
        })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        if not populate.scheduler and scheduler != "Automatic":
            populate.scheduler = scheduler

        args = vars(populate)
        args.pop('script_name', None)
        args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them
        args.pop('alwayson_scripts', None)
        args.pop('infotext', None)

        script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)

        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        add_task_to_queue(task_id)

        with self.queue_lock:
            with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
                p.is_api = True
                p.scripts = script_runner
                p.outpath_grids = opts.outdir_txt2img_grids
                p.outpath_samples = opts.outdir_txt2img_samples

                try:
                    shared.state.begin(job="scripts_txt2img")
                    start_task(task_id)
                    if selectable_scripts is not None:
                        p.script_args = script_args
                        processed = scripts.scripts_txt2img.run(p, *p.script_args)  # Need to pass args as list here
                    else:
                        p.script_args = tuple(script_args)  # Need to pass args as tuple here
                        processed = process_images(p)
                finally:
                    # finish the task even when processing raises, so it is not left as the current task
                    finish_task(task_id)
                    shared.state.end()
                    shared.total_tqdm.clear()

        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

        return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

    def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
        """POST /sdapi/v1/img2img: run img2img (optionally through a selectable script) under the queue lock.

        Init images and mask are decoded with `decode_base64_to_image`. They are echoed back in
        `parameters` only when `include_init_images` is set.

        Raises:
            HTTPException: 404 when `init_images` is missing.
        """
        task_id = img2imgreq.force_task_id or create_task_id("img2img")

        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")

        mask = img2imgreq.mask
        if mask:
            mask = decode_base64_to_image(mask)

        script_runner = scripts.scripts_img2img

        infotext_script_args = {}
        self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)

        selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
        sampler, scheduler = sd_samplers.get_sampler_and_scheduler(img2imgreq.sampler_name or img2imgreq.sampler_index, img2imgreq.scheduler)

        populate = img2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(sampler),
            "do_not_save_samples": not img2imgreq.save_images,
            "do_not_save_grid": not img2imgreq.save_images,
            "mask": mask,
        })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        if not populate.scheduler and scheduler != "Automatic":
            populate.scheduler = scheduler

        args = vars(populate)
        args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
        args.pop('script_name', None)
        args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them
        args.pop('alwayson_scripts', None)
        args.pop('infotext', None)

        script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)

        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        add_task_to_queue(task_id)

        with self.queue_lock:
            with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
                p.init_images = [decode_base64_to_image(x) for x in init_images]
                p.is_api = True
                p.scripts = script_runner
                p.outpath_grids = opts.outdir_img2img_grids
                p.outpath_samples = opts.outdir_img2img_samples

                try:
                    shared.state.begin(job="scripts_img2img")
                    start_task(task_id)
                    if selectable_scripts is not None:
                        p.script_args = script_args
                        processed = scripts.scripts_img2img.run(p, *p.script_args)  # Need to pass args as list here
                    else:
                        p.script_args = tuple(script_args)  # Need to pass args as tuple here
                        processed = process_images(p)
                finally:
                    # finish the task even when processing raises, so it is not left as the current task
                    finish_task(task_id)
                    shared.state.end()
                    shared.total_tqdm.clear()

        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

        if not img2imgreq.include_init_images:
            img2imgreq.init_images = None
            img2imgreq.mask = None

        return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())

    def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
        """POST /sdapi/v1/extra-single-image: run postprocessing (extras mode 0) on one image."""
        reqDict = setUpscalers(req)

        reqDict['image'] = decode_base64_to_image(reqDict['image'])

        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)

        return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])

    def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
        """POST /sdapi/v1/extra-batch-images: run postprocessing (extras mode 1) on a list of images."""
        reqDict = setUpscalers(req)

        image_list = reqDict.pop('imageList', [])
        image_folder = [decode_base64_to_image(x.data) for x in image_list]

        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)

        return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])

    def pnginfoapi(self, req: models.PNGInfoRequest):
        """POST /sdapi/v1/png-info: read generation info from an image and parse it into parameters."""
        image = decode_base64_to_image(req.image.strip())
        if image is None:
            return models.PNGInfoResponse(info="")

        geninfo, items = images.read_info_from_image(image)
        if geninfo is None:
            geninfo = ""

        params = infotext_utils.parse_generation_parameters(geninfo)
        script_callbacks.infotext_pasted_callback(geninfo, params)

        return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)

    def progressapi(self, req: models.ProgressRequest = Depends()):
        """GET /sdapi/v1/progress: report the progress of the running job.

        `progress` is a fraction clamped to at most 1; `eta_relative` is the estimated remaining
        seconds. The current preview image is included unless `skip_current_image` is set.
        """
        # copy from check_progress_call of ui.py

        if shared.state.job_count == 0:
            return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)

        # avoid dividing zero
        progress = 0.01

        if shared.state.job_count > 0:
            progress += shared.state.job_no / shared.state.job_count
        if shared.state.sampling_steps > 0:
            progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

        time_since_start = time.time() - shared.state.time_start
        eta = (time_since_start / progress)
        eta_relative = eta - time_since_start

        progress = min(progress, 1)

        shared.state.set_current_image()

        current_image = None
        if shared.state.current_image and not req.skip_current_image:
            current_image = encode_pil_to_base64(shared.state.current_image)

        # read through the module so the value is the one current now, not the import-time snapshot
        return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=task_progress.current_task)

    def interrogateapi(self, interrogatereq: models.InterrogateRequest):
        """POST /sdapi/v1/interrogate: caption an image with the "clip" or "deepdanbooru" model.

        Raises:
            HTTPException: 404 when the image is missing or the model name is unknown.
        """
        image_b64 = interrogatereq.image
        if image_b64 is None:
            raise HTTPException(status_code=404, detail="Image not found")

        img = decode_base64_to_image(image_b64)
        img = img.convert('RGB')

        # Override object param
        with self.queue_lock:
            if interrogatereq.model == "clip":
                processed = shared.interrogator.interrogate(img)
            elif interrogatereq.model == "deepdanbooru":
                processed = deepbooru.model.tag(img)
            else:
                raise HTTPException(status_code=404, detail="Model not found")

        return models.InterrogateResponse(caption=processed)

    def interruptapi(self):
        """POST /sdapi/v1/interrupt: ask the running job to stop."""
        shared.state.interrupt()

        return {}

    def unloadapi(self):
        """POST /sdapi/v1/unload-checkpoint: unload the current model weights."""
        sd_models.unload_model_weights()

        return {}

    def reloadapi(self):
        """POST /sdapi/v1/reload-checkpoint: move the current model back to its device."""
        sd_models.send_model_to_device(shared.sd_model)

        return {}

    def skip(self):
        """POST /sdapi/v1/skip: skip the current job item."""
        shared.state.skip()

    def get_config(self):
        """GET /sdapi/v1/options: return every option stored in `shared.opts.data`."""
        options = {}
        for key in shared.opts.data.keys():
            metadata = shared.opts.data_labels.get(key)
            default = metadata.default if metadata is not None else None
            options[key] = shared.opts.data.get(key, default)

        return options

    def set_config(self, req: dict[str, Any]):
        """POST /sdapi/v1/options: set and persist options.

        Raises:
            RuntimeError: if `sd_model_checkpoint` names an unknown checkpoint (nothing is set then).
        """
        checkpoint_name = req.get("sd_model_checkpoint", None)
        if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
            raise RuntimeError(f"model {checkpoint_name!r} not found")

        for k, v in req.items():
            shared.opts.set(k, v, is_api=True)

        shared.opts.save(shared.config_filename)
        return

    def get_cmd_flags(self):
        """GET /sdapi/v1/cmd-flags: return the parsed command-line options."""
        return vars(shared.cmd_opts)

    def get_samplers(self):
        """GET /sdapi/v1/samplers: list available samplers with aliases and options."""
        return [{"name": sampler[0], "aliases": sampler[2], "options": sampler[3]} for sampler in sd_samplers.all_samplers]

    def get_schedulers(self):
        """GET /sdapi/v1/schedulers: list available schedulers."""
        return [
            {
                "name": scheduler.name,
                "label": scheduler.label,
                "aliases": scheduler.aliases,
                "default_rho": scheduler.default_rho,
                "need_inner_model": scheduler.need_inner_model,
            }
            for scheduler in sd_schedulers.schedulers]

    def get_upscalers(self):
        """GET /sdapi/v1/upscalers: list available upscalers."""
        return [
            {
                "name": upscaler.name,
                "model_name": upscaler.scaler.model_name,
                "model_path": upscaler.data_path,
                "model_url": None,
                "scale": upscaler.scale,
            }
            for upscaler in shared.sd_upscalers
        ]

    def get_latent_upscale_modes(self):
        """GET /sdapi/v1/latent-upscale-modes: list the names of latent upscale modes."""
        return [
            {
                "name": upscale_mode,
            }
            for upscale_mode in [*(shared.latent_upscale_modes or {})]
        ]

    def get_sd_models(self):
        """GET /sdapi/v1/sd-models: list known checkpoints with hashes and nearby config files."""
        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]

    def get_sd_vaes(self):
        """GET /sdapi/v1/sd-vae: list known VAE files."""
        import modules.sd_vae as sd_vae
        return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()]

    def get_hypernetworks(self):
        """GET /sdapi/v1/hypernetworks: list known hypernetworks and their paths."""
        return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]

    def get_face_restorers(self):
        """GET /sdapi/v1/face-restorers: list face restorers."""
        return [{"name": x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]

    def get_realesrgan_models(self):
        """GET /sdapi/v1/realesrgan-models: list RealESRGAN models."""
        return [{"name": x.name, "path": x.data_path, "scale": x.scale} for x in get_realesrgan_models(None)]

    def get_prompt_styles(self):
        """GET /sdapi/v1/prompt-styles: list saved prompt styles."""
        styleList = []
        for k in shared.prompt_styles.styles:
            style = shared.prompt_styles.styles[k]
            styleList.append({"name": style[0], "prompt": style[1], "negative_prompt": style[2]})

        return styleList

    def get_embeddings(self):
        """GET /sdapi/v1/embeddings: describe loaded and skipped textual-inversion embeddings."""
        db = sd_hijack.model_hijack.embedding_db

        def convert_embedding(embedding):
            return {
                "step": embedding.step,
                "sd_checkpoint": embedding.sd_checkpoint,
                "sd_checkpoint_name": embedding.sd_checkpoint_name,
                "shape": embedding.shape,
                "vectors": embedding.vectors,
            }

        def convert_embeddings(embeddings):
            return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}

        return {
            "loaded": convert_embeddings(db.word_embeddings),
            "skipped": convert_embeddings(db.skipped_embeddings),
        }

    def refresh_embeddings(self):
        """POST /sdapi/v1/refresh-embeddings: force-reload textual-inversion embeddings."""
        with self.queue_lock:
            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)

    def refresh_checkpoints(self):
        """POST /sdapi/v1/refresh-checkpoints: rescan checkpoints."""
        with self.queue_lock:
            shared.refresh_checkpoints()

    def refresh_vae(self):
        """POST /sdapi/v1/refresh-vae: rescan VAE files."""
        with self.queue_lock:
            shared_items.refresh_vae_list()

    def create_embedding(self, args: dict):
        """POST /sdapi/v1/create/embedding: create an empty embedding, then reload embeddings.

        AssertionErrors from creation are reported in the response text.
        """
        try:
            shared.state.begin(job="create_embedding")
            filename = create_embedding(**args)  # create empty embedding
            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()  # reload embeddings so new one can be immediately used
            return models.CreateResponse(info=f"create embedding filename: {filename}")
        except AssertionError as e:
            # route is declared with response_model=CreateResponse, so errors use it as well
            return models.CreateResponse(info=f"create embedding error: {e}")
        finally:
            shared.state.end()

    def create_hypernetwork(self, args: dict):
        """POST /sdapi/v1/create/hypernetwork: create an empty hypernetwork.

        AssertionErrors from creation are reported in the response text.
        """
        try:
            shared.state.begin(job="create_hypernetwork")
            filename = create_hypernetwork(**args)  # create empty embedding
            return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
        except AssertionError as e:
            # route is declared with response_model=CreateResponse, so errors use it as well
            return models.CreateResponse(info=f"create hypernetwork error: {e}")
        finally:
            shared.state.end()

    def train_embedding(self, args: dict):
        """POST /sdapi/v1/train/embedding: train a textual-inversion embedding with `args`.

        Cross-attention optimizations are undone during training unless
        `opts.training_xattention_optimizations` is set. Errors are reported in the response text.
        """
        try:
            shared.state.begin(job="train_embedding")
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                _, filename = train_embedding(**args)  # can take a long time to complete
            except Exception as e:
                error = e
            finally:
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
            return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
        except Exception as msg:
            return models.TrainResponse(info=f"train embedding error: {msg}")
        finally:
            shared.state.end()

    def train_hypernetwork(self, args: dict):
        """POST /sdapi/v1/train/hypernetwork: train a hypernetwork with `args`.

        Clears `shared.loaded_hypernetworks` first. Cross-attention optimizations are undone during
        training unless `opts.training_xattention_optimizations` is set. Errors are reported in the
        response text.
        """
        try:
            shared.state.begin(job="train_hypernetwork")
            shared.loaded_hypernetworks = []
            apply_optimizations = shared.opts.training_xattention_optimizations
            error = None
            filename = ''
            if not apply_optimizations:
                sd_hijack.undo_optimizations()
            try:
                _, filename = train_hypernetwork(**args)
            except Exception as e:
                error = e
            finally:
                # put these submodules back on devices.device; presumably training may have moved them
                shared.sd_model.cond_stage_model.to(devices.device)
                shared.sd_model.first_stage_model.to(devices.device)
                if not apply_optimizations:
                    sd_hijack.apply_optimizations()
            # shared.state.end() runs once, in the outer finally (it used to be called twice)
            return models.TrainResponse(info=f"train hypernetwork complete: filename: {filename} error: {error}")
        except Exception as exc:
            return models.TrainResponse(info=f"train hypernetwork error: {exc}")
        finally:
            shared.state.end()

    def get_memory(self):
        """GET /sdapi/v1/memory: report process RAM usage and, if available, CUDA memory statistics.

        Each section holds an `error` entry instead if gathering it fails.
        """
        try:
            import psutil
            process = psutil.Process(os.getpid())
            res = process.memory_info()  # only rss is cross-platform guaranteed so we dont rely on other values
            ram_total = 100 * res.rss / process.memory_percent()  # and total memory is calculated as actual value is not cross-platform safe
            ram = {'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total}
        except Exception as err:
            ram = {'error': f'{err}'}
        try:
            import torch
            if torch.cuda.is_available():
                s = torch.cuda.mem_get_info()
                system = {'free': s[0], 'used': s[1] - s[0], 'total': s[1]}
                s = dict(torch.cuda.memory_stats(shared.device))
                allocated = {'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak']}
                reserved = {'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak']}
                active = {'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak']}
                inactive = {'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak']}
                warnings = {'retries': s['num_alloc_retries'], 'oom': s['num_ooms']}
                cuda = {
                    'system': system,
                    'active': active,
                    'allocated': allocated,
                    'reserved': reserved,
                    'inactive': inactive,
                    'events': warnings,
                }
            else:
                cuda = {'error': 'unavailable'}
        except Exception as err:
            cuda = {'error': f'{err}'}
        return models.MemoryResponse(ram=ram, cuda=cuda)

    def get_extensions_list(self):
        """GET /sdapi/v1/extensions: list extensions that have a git remote, with repo info."""
        from modules import extensions
        extensions.list_extensions()
        ext_list = []
        for ext in extensions.extensions:
            ext: extensions.Extension
            ext.read_info_from_repo()
            if ext.remote is not None:
                ext_list.append({
                    "name": ext.name,
                    "remote": ext.remote,
                    "branch": ext.branch,
                    "commit_hash": ext.commit_hash,
                    "commit_date": ext.commit_date,
                    "version": ext.version,
                    "enabled": ext.enabled
                })
        return ext_list

    def launch(self, server_name, port, root_path):
        """Include the router and serve the app with uvicorn (blocking)."""
        self.app.include_router(self.router)
        uvicorn.run(
            self.app,
            host=server_name,
            port=port,
            timeout_keep_alive=shared.cmd_opts.timeout_keep_alive,
            root_path=root_path,
            ssl_keyfile=shared.cmd_opts.tls_keyfile,
            ssl_certfile=shared.cmd_opts.tls_certfile
        )

    def kill_webui(self):
        """POST /sdapi/v1/server-kill: stop the program via restart.stop_program()."""
        restart.stop_program()

    def restart_webui(self):
        """POST /sdapi/v1/server-restart: restart if restartable, otherwise respond 501."""
        if restart.is_restartable():
            restart.restart_program()
        return Response(status_code=501)

    def stop_webui(self):
        """POST /sdapi/v1/server-stop: request a stop through `shared.state.server_command`."""
        shared.state.server_command = "stop"
        return Response("Stopping.")