Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AUTOMATIC1111
GitHub Repository: AUTOMATIC1111/stable-diffusion-webui
Path: blob/master/modules/api/api.py
2447 views
1
import base64
2
import io
3
import os
4
import time
5
import datetime
6
import uvicorn
7
import ipaddress
8
import requests
9
import gradio as gr
10
from threading import Lock
11
from io import BytesIO
12
from fastapi import APIRouter, Depends, FastAPI, Request, Response
13
from fastapi.security import HTTPBasic, HTTPBasicCredentials
14
from fastapi.exceptions import HTTPException
15
from fastapi.responses import JSONResponse
16
from fastapi.encoders import jsonable_encoder
17
from secrets import compare_digest
18
19
import modules.shared as shared
20
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers
21
from modules.api import models
22
from modules.shared import opts
23
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
24
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
25
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
26
from PIL import PngImagePlugin
27
from modules.sd_models_config import find_checkpoint_config_near_filename
28
from modules.realesrgan_model import get_realesrgan_models
29
from modules import devices
30
from typing import Any
31
import piexif
32
import piexif.helper
33
from contextlib import closing
34
from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task
35
36
def script_name_to_index(name, scripts):
    """Return the index of the script whose title matches *name*, case-insensitively.

    Raises HTTPException(422) when no script with that title exists.
    """
    titles = [entry.title().lower() for entry in scripts]
    try:
        return titles.index(name.lower())
    except Exception as err:
        raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from err
41
42
43
def validate_sampler_name(name):
    """Check that *name* is a registered sampler and return it unchanged.

    Raises HTTPException(400) for an unknown sampler.
    """
    if sd_samplers.all_samplers_map.get(name) is None:
        raise HTTPException(status_code=400, detail="Sampler not found")
    return name
49
50
51
def setUpscalers(req: dict):
    """Rename the request's upscaler fields to the keyword names run_extras expects."""
    fields = vars(req)
    for old_key, new_key in (('upscaler_1', 'extras_upscaler_1'), ('upscaler_2', 'extras_upscaler_2')):
        fields[new_key] = fields.pop(old_key, None)
    return fields
56
57
58
def verify_url(url):
    """Returns True if the url refers to a global resource."""

    import socket
    from urllib.parse import urlparse

    try:
        hostname = urlparse(url).netloc
        # Resolve every address for the host; all must be globally routable.
        _, _, addresses = socket.gethostbyname_ex(hostname)
        return all(ipaddress.ip_address(addr).is_global for addr in addresses)
    except Exception:
        # Unparseable URL or failed resolution: treat as not-global (reject).
        return False
75
76
77
def decode_base64_to_image(encoding):
    """Turn an API image reference into a PIL image.

    *encoding* may be an http(s) URL (fetched, subject to the API request
    options), a data: URI, or a plain base64-encoded image payload.
    Raises HTTPException(500) when fetching is disabled/forbidden or the
    payload cannot be decoded as an image.
    """
    if encoding.startswith(("http://", "https://")):
        if not opts.api_enable_requests:
            raise HTTPException(status_code=500, detail="Requests not allowed")

        if opts.api_forbid_local_requests and not verify_url(encoding):
            raise HTTPException(status_code=500, detail="Request to local resource not allowed")

        headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
        response = requests.get(encoding, timeout=30, headers=headers)
        try:
            return images.read(BytesIO(response.content))
        except Exception as err:
            raise HTTPException(status_code=500, detail="Invalid image url") from err

    if encoding.startswith("data:image/"):
        # Strip the "data:image/...;base64," prefix, keeping only the payload.
        encoding = encoding.split(";")[1].split(",")[1]

    try:
        return images.read(BytesIO(base64.b64decode(encoding)))
    except Exception as err:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from err
100
101
102
def encode_pil_to_base64(image):
    """Serialize a PIL image to base64 bytes using opts.samples_format.

    Strings pass through untouched (already-encoded images).
    Raises HTTPException(500) for an unsupported format setting.
    """
    if isinstance(image, str):
        return image

    fmt = opts.samples_format.lower()
    with io.BytesIO() as output_bytes:
        if fmt == 'png':
            # Carry textual metadata (e.g. generation parameters) into the PNG.
            metadata = PngImagePlugin.PngInfo()
            use_metadata = False
            for key, value in image.info.items():
                if isinstance(key, str) and isinstance(value, str):
                    metadata.add_text(key, value)
                    use_metadata = True
            image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)

        elif fmt in ("jpg", "jpeg", "webp"):
            if image.mode in ("RGBA", "P"):
                image = image.convert("RGB")
            # JPEG/WEBP have no text chunks; stash parameters in EXIF UserComment.
            parameters = image.info.get('parameters', None)
            exif_bytes = piexif.dump({
                "Exif": {piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode")}
            })
            save_format = "JPEG" if fmt in ("jpg", "jpeg") else "WEBP"
            image.save(output_bytes, format=save_format, exif=exif_bytes, quality=opts.jpeg_quality)

        else:
            raise HTTPException(status_code=500, detail="Invalid image format")

        return base64.b64encode(output_bytes.getvalue())
133
134
135
def api_middleware(app: FastAPI):
    """Install request logging/timing middleware and uniform JSON error handling on *app*."""
    rich_available = False
    try:
        # Rich tracebacks are opt-in via the WEBUI_RICH_EXCEPTIONS env var; any
        # failure here silently falls back to plain error reporting.
        if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
            import anyio # importing just so it can be placed on silent list
            import starlette # importing just so it can be placed on silent list
            from rich.console import Console
            console = Console()
            rich_available = True
    except Exception:
        pass

    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        # Time every request and expose the duration in X-Process-Time.
        ts = time.time()
        res: Response = await call_next(req)
        duration = str(round(time.time() - ts, 4))
        res.headers["X-Process-Time"] = duration
        endpoint = req.scope.get('path', 'err')
        # Only API endpoints are logged, and only when --api-log is enabled.
        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
                t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                code=res.status_code,
                ver=req.scope.get('http_version', '0.0'),
                cli=req.scope.get('client', ('0:0.0.0', 0))[0],
                prot=req.scope.get('scheme', 'err'),
                method=req.scope.get('method', 'err'),
                endpoint=endpoint,
                duration=duration,
            ))
        return res

    def handle_exception(request: Request, e: Exception):
        # Shape every error the same way for API clients.
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get('detail', ''),
            "body": vars(e).get('body', ''),
            "errors": str(e),
        }
        if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
            message = f"API error: {request.method}: {request.url} {err}"
            if rich_available:
                print(message)
                console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
            else:
                errors.report(message, exc_info=True)
        return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)
197
198
199
class Api:
200
    def __init__(self, app: FastAPI, queue_lock: Lock):
        """Register every /sdapi/v1 route on *app*; heavy work is serialized via *queue_lock*."""
        if shared.cmd_opts.api_auth:
            # --api-auth holds comma-separated "user:password" pairs.
            self.credentials = {}
            for auth in shared.cmd_opts.api_auth.split(","):
                user, password = auth.split(":")
                self.credentials[user] = password

        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        api_middleware(self.app)
        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
        self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
        self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
        self.add_api_route("/sdapi/v1/schedulers", self.get_schedulers, methods=["GET"], response_model=list[models.SchedulerItem])
        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
        self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
        self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
        self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])

        # Server lifecycle endpoints are only exposed when explicitly enabled.
        if shared.cmd_opts.api_server_stop:
            self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
            self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
            self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])

        self.default_script_arg_txt2img = []
        self.default_script_arg_img2img = []

        txt2img_script_runner = scripts.scripts_txt2img
        img2img_script_runner = scripts.scripts_img2img

        # When running with --nowebui the script runners are empty; building the
        # UI populates them as a side effect.
        if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
            ui.create_ui()

        if not txt2img_script_runner.scripts:
            txt2img_script_runner.initialize_scripts(False)
        if not self.default_script_arg_txt2img:
            self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)

        if not img2img_script_runner.scripts:
            img2img_script_runner.initialize_scripts(True)
        if not self.default_script_arg_img2img:
            self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
271
272
273
274
def add_api_route(self, path: str, endpoint, **kwargs):
275
if shared.cmd_opts.api_auth:
276
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
277
return self.app.add_api_route(path, endpoint, **kwargs)
278
279
def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
280
if credentials.username in self.credentials:
281
if compare_digest(credentials.password, self.credentials[credentials.username]):
282
return True
283
284
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
285
286
def get_selectable_script(self, script_name, script_runner):
287
if script_name is None or script_name == "":
288
return None, None
289
290
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
291
script = script_runner.selectable_scripts[script_idx]
292
return script, script_idx
293
294
def get_scripts_list(self):
295
t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
296
i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
297
298
return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
299
300
def get_script_info(self):
301
res = []
302
303
for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
304
res += [script.api_info for script in script_list if script.api_info is not None]
305
306
return res
307
308
def get_script(self, script_name, script_runner):
309
if script_name is None or script_name == "":
310
return None, None
311
312
script_idx = script_name_to_index(script_name, script_runner.scripts)
313
return script_runner.scripts[script_idx]
314
315
    def init_default_script_args(self, script_runner):
        """Build the default flat script_args list by instantiating each script's UI
        and collecting every component's default value."""
        #find max idx from the scripts in runner and generate a none array to init script_args
        last_arg_index = 1
        for script in script_runner.scripts:
            if last_arg_index < script.args_to:
                last_arg_index = script.args_to
        # None everywhere except position 0 to initialize script args
        script_args = [None]*last_arg_index
        script_args[0] = 0  # 0 means "no selectable script chosen"

        # get default values
        with gr.Blocks(): # will throw errors calling ui function without this
            for script in script_runner.scripts:
                # NOTE(review): script.ui() is called twice per script (once to
                # test, once to iterate) — presumably cheap; confirm before changing.
                if script.ui(script.is_img2img):
                    ui_default_values = []
                    for elem in script.ui(script.is_img2img):
                        ui_default_values.append(elem.value)
                    # Slot the defaults into this script's arg range.
                    script_args[script.args_from:script.args_to] = ui_default_values
        return script_args
334
335
    def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
        """Merge per-request script arguments over a copy of the default script_args.

        input_script_args: optional {index: value} overrides (e.g. parsed from
        infotext); applied first so explicit request args take precedence.
        Raises HTTPException(422) for unknown or non-always-on script names.
        """
        script_args = default_script_args.copy()

        if input_script_args is not None:
            for index, value in input_script_args.items():
                script_args[index] = value

        # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
        if selectable_scripts:
            script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
            script_args[0] = selectable_idx + 1

        # Now check for always on scripts
        if request.alwayson_scripts:
            for alwayson_script_name in request.alwayson_scripts.keys():
                alwayson_script = self.get_script(alwayson_script_name, script_runner)
                if alwayson_script is None:
                    raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
                # Selectable script in always on script param check
                if alwayson_script.alwayson is False:
                    raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
                # always on script with no arg should always run so you don't really need to add them to the requests
                if "args" in request.alwayson_scripts[alwayson_script_name]:
                    # min between arg length in scriptrunner and arg length in the request
                    for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
                        script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
        return script_args
362
363
    def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
        """Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext.

        If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.

        Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
        """

        if not request.infotext:
            return {}

        possible_fields = infotext_utils.paste_fields[tabname]["fields"]
        # NOTE(review): `hasattr(request, "request")` looks like it was meant to be
        # a pydantic v2 detection check (e.g. hasattr(request, "model_dump")) — confirm.
        set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have different names for this
        params = infotext_utils.parse_generation_parameters(request.infotext)

        def get_field_value(field, params):
            # A field either computes its value from all params or reads one label.
            value = field.function(params) if field.function else params.get(field.label)
            if value is None:
                return None

            # Coerce the parsed string toward the declared pydantic field type,
            # falling back to the type of the UI component's current value.
            if field.api in request.__fields__:
                target_type = request.__fields__[field.api].type_
            else:
                target_type = type(field.component.value)

            if target_type == type(None):
                return None

            if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value
                value = value.get('value')

            if value is not None and not isinstance(value, target_type):
                value = target_type(value)

            return value

        for field in possible_fields:
            if not field.api:
                continue

            # Fields the caller set explicitly win over infotext values.
            if field.api in set_fields:
                continue

            value = get_field_value(field, params)
            if value is not None:
                setattr(request, field.api, value)

        if request.override_settings is None:
            request.override_settings = {}

        overridden_settings = infotext_utils.get_override_settings(params)
        for _, setting_name, value in overridden_settings:
            if setting_name not in request.override_settings:
                request.override_settings[setting_name] = value

        # Map infotext values back onto script argument slots, if requested.
        if script_runner is not None and mentioned_script_args is not None:
            indexes = {v: i for i, v in enumerate(script_runner.inputs)}
            script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)

            for field, index in script_fields:
                value = get_field_value(field, params)

                if value is None:
                    continue

                mentioned_script_args[index] = value

        return params
431
432
    def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
        """Handle POST /sdapi/v1/txt2img: run a txt2img job and return base64 images."""
        task_id = txt2imgreq.force_task_id or create_task_id("txt2img")

        script_runner = scripts.scripts_txt2img

        # Let infotext fill any request fields the caller did not set explicitly.
        infotext_script_args = {}
        self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)

        selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
        sampler, scheduler = sd_samplers.get_sampler_and_scheduler(txt2imgreq.sampler_name or txt2imgreq.sampler_index, txt2imgreq.scheduler)

        populate = txt2imgreq.copy(update={ # Override __init__ params
            "sampler_name": validate_sampler_name(sampler),
            "do_not_save_samples": not txt2imgreq.save_images,
            "do_not_save_grid": not txt2imgreq.save_images,
        })
        if populate.sampler_name:
            populate.sampler_index = None # prevent a warning later on

        if not populate.scheduler and scheduler != "Automatic":
            populate.scheduler = scheduler

        # Strip API-only fields before forwarding kwargs to the processing class.
        args = vars(populate)
        args.pop('script_name', None)
        args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
        args.pop('alwayson_scripts', None)
        args.pop('infotext', None)

        script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)

        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        add_task_to_queue(task_id)

        with self.queue_lock:
            with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
                p.is_api = True
                p.scripts = script_runner
                p.outpath_grids = opts.outdir_txt2img_grids
                p.outpath_samples = opts.outdir_txt2img_samples

                try:
                    shared.state.begin(job="scripts_txt2img")
                    start_task(task_id)
                    if selectable_scripts is not None:
                        p.script_args = script_args
                        processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
                    else:
                        p.script_args = tuple(script_args) # Need to pass args as tuple here
                        processed = process_images(p)
                    finish_task(task_id)
                finally:
                    shared.state.end()
                    shared.total_tqdm.clear()

        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

        return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
491
492
    def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
        """Handle POST /sdapi/v1/img2img: run an img2img job and return base64 images."""
        task_id = img2imgreq.force_task_id or create_task_id("img2img")

        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")

        mask = img2imgreq.mask
        if mask:
            mask = decode_base64_to_image(mask)

        script_runner = scripts.scripts_img2img

        # Let infotext fill any request fields the caller did not set explicitly.
        infotext_script_args = {}
        self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)

        selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
        sampler, scheduler = sd_samplers.get_sampler_and_scheduler(img2imgreq.sampler_name or img2imgreq.sampler_index, img2imgreq.scheduler)

        populate = img2imgreq.copy(update={ # Override __init__ params
            "sampler_name": validate_sampler_name(sampler),
            "do_not_save_samples": not img2imgreq.save_images,
            "do_not_save_grid": not img2imgreq.save_images,
            "mask": mask,
        })
        if populate.sampler_name:
            populate.sampler_index = None # prevent a warning later on

        if not populate.scheduler and scheduler != "Automatic":
            populate.scheduler = scheduler

        # Strip API-only fields before forwarding kwargs to the processing class.
        args = vars(populate)
        args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
        args.pop('script_name', None)
        args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
        args.pop('alwayson_scripts', None)
        args.pop('infotext', None)

        script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)

        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        add_task_to_queue(task_id)

        with self.queue_lock:
            with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
                p.init_images = [decode_base64_to_image(x) for x in init_images]
                p.is_api = True
                p.scripts = script_runner
                p.outpath_grids = opts.outdir_img2img_grids
                p.outpath_samples = opts.outdir_img2img_samples

                try:
                    shared.state.begin(job="scripts_img2img")
                    start_task(task_id)
                    if selectable_scripts is not None:
                        p.script_args = script_args
                        processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
                    else:
                        p.script_args = tuple(script_args) # Need to pass args as tuple here
                        processed = process_images(p)
                    finish_task(task_id)
                finally:
                    shared.state.end()
                    shared.total_tqdm.clear()

        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

        # Avoid echoing the (potentially large) input images back to the caller.
        if not img2imgreq.include_init_images:
            img2imgreq.init_images = None
            img2imgreq.mask = None

        return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
566
567
def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
568
reqDict = setUpscalers(req)
569
570
reqDict['image'] = decode_base64_to_image(reqDict['image'])
571
572
with self.queue_lock:
573
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
574
575
return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
576
577
def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
578
reqDict = setUpscalers(req)
579
580
image_list = reqDict.pop('imageList', [])
581
image_folder = [decode_base64_to_image(x.data) for x in image_list]
582
583
with self.queue_lock:
584
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
585
586
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
587
588
def pnginfoapi(self, req: models.PNGInfoRequest):
589
image = decode_base64_to_image(req.image.strip())
590
if image is None:
591
return models.PNGInfoResponse(info="")
592
593
geninfo, items = images.read_info_from_image(image)
594
if geninfo is None:
595
geninfo = ""
596
597
params = infotext_utils.parse_generation_parameters(geninfo)
598
script_callbacks.infotext_pasted_callback(geninfo, params)
599
600
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
601
602
    def progressapi(self, req: models.ProgressRequest = Depends()):
        """Report generation progress, relative ETA, state, and optionally a preview image."""
        # copy from check_progress_call of ui.py

        if shared.state.job_count == 0:
            return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)

        # avoid dividing zero
        progress = 0.01

        if shared.state.job_count > 0:
            progress += shared.state.job_no / shared.state.job_count
        if shared.state.sampling_steps > 0:
            progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

        # ETA is extrapolated linearly from elapsed time and current progress.
        time_since_start = time.time() - shared.state.time_start
        eta = (time_since_start/progress)
        eta_relative = eta-time_since_start

        progress = min(progress, 1)

        shared.state.set_current_image()

        current_image = None
        if shared.state.current_image and not req.skip_current_image:
            current_image = encode_pil_to_base64(shared.state.current_image)

        return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)
629
630
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
631
image_b64 = interrogatereq.image
632
if image_b64 is None:
633
raise HTTPException(status_code=404, detail="Image not found")
634
635
img = decode_base64_to_image(image_b64)
636
img = img.convert('RGB')
637
638
# Override object param
639
with self.queue_lock:
640
if interrogatereq.model == "clip":
641
processed = shared.interrogator.interrogate(img)
642
elif interrogatereq.model == "deepdanbooru":
643
processed = deepbooru.model.tag(img)
644
else:
645
raise HTTPException(status_code=404, detail="Model not found")
646
647
return models.InterrogateResponse(caption=processed)
648
649
def interruptapi(self):
650
shared.state.interrupt()
651
652
return {}
653
654
def unloadapi(self):
655
sd_models.unload_model_weights()
656
657
return {}
658
659
def reloadapi(self):
660
sd_models.send_model_to_device(shared.sd_model)
661
662
return {}
663
664
def skip(self):
665
shared.state.skip()
666
667
def get_config(self):
668
options = {}
669
for key in shared.opts.data.keys():
670
metadata = shared.opts.data_labels.get(key)
671
if(metadata is not None):
672
options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
673
else:
674
options.update({key: shared.opts.data.get(key, None)})
675
676
return options
677
678
def set_config(self, req: dict[str, Any]):
679
checkpoint_name = req.get("sd_model_checkpoint", None)
680
if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
681
raise RuntimeError(f"model {checkpoint_name!r} not found")
682
683
for k, v in req.items():
684
shared.opts.set(k, v, is_api=True)
685
686
shared.opts.save(shared.config_filename)
687
return
688
689
def get_cmd_flags(self):
690
return vars(shared.cmd_opts)
691
692
def get_samplers(self):
693
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
694
695
def get_schedulers(self):
696
return [
697
{
698
"name": scheduler.name,
699
"label": scheduler.label,
700
"aliases": scheduler.aliases,
701
"default_rho": scheduler.default_rho,
702
"need_inner_model": scheduler.need_inner_model,
703
}
704
for scheduler in sd_schedulers.schedulers]
705
706
def get_upscalers(self):
707
return [
708
{
709
"name": upscaler.name,
710
"model_name": upscaler.scaler.model_name,
711
"model_path": upscaler.data_path,
712
"model_url": None,
713
"scale": upscaler.scale,
714
}
715
for upscaler in shared.sd_upscalers
716
]
717
718
def get_latent_upscale_modes(self):
719
return [
720
{
721
"name": upscale_mode,
722
}
723
for upscale_mode in [*(shared.latent_upscale_modes or {})]
724
]
725
726
def get_sd_models(self):
727
import modules.sd_models as sd_models
728
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
729
730
def get_sd_vaes(self):
731
import modules.sd_vae as sd_vae
732
return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()]
733
734
def get_hypernetworks(self):
735
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
736
737
def get_face_restorers(self):
738
return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
739
740
def get_realesrgan_models(self):
741
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
742
743
def get_prompt_styles(self):
744
styleList = []
745
for k in shared.prompt_styles.styles:
746
style = shared.prompt_styles.styles[k]
747
styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
748
749
return styleList
750
751
def get_embeddings(self):
    """Report textual-inversion embeddings, split into loaded and skipped sets."""
    db = sd_hijack.model_hijack.embedding_db

    def as_dict(emb):
        # flatten one embedding object into plain JSON-friendly fields
        return {
            "step": emb.step,
            "sd_checkpoint": emb.sd_checkpoint,
            "sd_checkpoint_name": emb.sd_checkpoint_name,
            "shape": emb.shape,
            "vectors": emb.vectors,
        }

    loaded = {emb.name: as_dict(emb) for emb in db.word_embeddings.values()}
    skipped = {emb.name: as_dict(emb) for emb in db.skipped_embeddings.values()}
    return {"loaded": loaded, "skipped": skipped}
def refresh_embeddings(self):
    """Force a rescan of textual-inversion embeddings from disk.

    Serialized behind the queue lock so it cannot race an in-flight generation.
    """
    with self.queue_lock:
        sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
def refresh_checkpoints(self):
    """Rescan the checkpoint list, serialized behind the queue lock."""
    with self.queue_lock:
        shared.refresh_checkpoints()
def refresh_vae(self):
    """Rescan the VAE list, serialized behind the queue lock."""
    with self.queue_lock:
        shared_items.refresh_vae_list()
def create_embedding(self, args: dict):
    """Create a new, empty textual-inversion embedding file.

    args is forwarded verbatim to modules.textual_inversion.create_embedding.
    Returns a CreateResponse naming the created file, or a TrainResponse
    describing the failure if creation raises AssertionError.
    """
    try:
        shared.state.begin(job="create_embedding")
        filename = create_embedding(**args)  # create empty embedding
        # reload embeddings so the new one can be used immediately
        sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
        # fix: the message previously contained no placeholder, so the created
        # filename was never reported and the variable went unused
        return models.CreateResponse(info=f"create embedding filename: {filename}")
    except AssertionError as e:
        return models.TrainResponse(info=f"create embedding error: {e}")
    finally:
        shared.state.end()
def create_hypernetwork(self, args: dict):
    """Create a new, empty hypernetwork file.

    args is forwarded verbatim to modules.hypernetworks.create_hypernetwork.
    Returns a CreateResponse naming the created file, or a TrainResponse
    describing the failure if creation raises AssertionError.
    """
    try:
        shared.state.begin(job="create_hypernetwork")
        filename = create_hypernetwork(**args)  # create empty hypernetwork
        # fix: the message previously contained no placeholder, so the created
        # filename was never reported and the variable went unused
        return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
    except AssertionError as e:
        return models.TrainResponse(info=f"create hypernetwork error: {e}")
    finally:
        shared.state.end()
def train_embedding(self, args: dict):
    """Run textual-inversion embedding training.

    args is forwarded verbatim to modules.textual_inversion.train_embedding,
    which can take a long time to complete. Cross-attention optimizations are
    temporarily undone during training unless the user opted to keep them.
    Always returns a TrainResponse; training errors are captured, not raised.
    """
    try:
        shared.state.begin(job="train_embedding")
        apply_optimizations = shared.opts.training_xattention_optimizations
        error = None
        filename = ''
        if not apply_optimizations:
            sd_hijack.undo_optimizations()
        try:
            embedding, filename = train_embedding(**args)  # can take a long time to complete
        except Exception as e:
            error = e
        finally:
            # restore the optimizations we removed above, even on failure
            if not apply_optimizations:
                sd_hijack.apply_optimizations()
        # fix: message previously had no placeholder for the trained filename
        return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
    except Exception as msg:
        return models.TrainResponse(info=f"train embedding error: {msg}")
    finally:
        shared.state.end()
def train_hypernetwork(self, args: dict):
    """Run hypernetwork training.

    args is forwarded verbatim to modules.hypernetworks.train_hypernetwork.
    Loaded hypernetworks are cleared first, cross-attention optimizations are
    temporarily undone unless the user opted to keep them, and the model's
    conditioning/first-stage components are moved back to the device afterwards.
    Always returns a TrainResponse; training errors are captured, not raised.
    """
    try:
        shared.state.begin(job="train_hypernetwork")
        shared.loaded_hypernetworks = []
        apply_optimizations = shared.opts.training_xattention_optimizations
        error = None
        filename = ''
        if not apply_optimizations:
            sd_hijack.undo_optimizations()
        try:
            hypernetwork, filename = train_hypernetwork(**args)
        except Exception as e:
            error = e
        finally:
            shared.sd_model.cond_stage_model.to(devices.device)
            shared.sd_model.first_stage_model.to(devices.device)
            if not apply_optimizations:
                sd_hijack.apply_optimizations()
            # note: the redundant shared.state.end() that used to sit here is
            # removed — the outer finally below already ends the job exactly once
        # fix: message previously said "train embedding" and had no filename placeholder
        return models.TrainResponse(info=f"train hypernetwork complete: filename: {filename} error: {error}")
    except Exception as exc:
        return models.TrainResponse(info=f"train hypernetwork error: {exc}")
    finally:
        shared.state.end()
def get_memory(self):
    """Report host RAM usage and, when CUDA is available, GPU memory statistics.

    Either section degrades to {'error': ...} independently if its probe fails.
    """
    try:
        import os
        import psutil
        process = psutil.Process(os.getpid())
        rss = process.memory_info().rss  # only rss is cross-platform guaranteed so we dont rely on other values
        # derive total memory from the usage percentage, as the absolute total is not cross-platform safe
        ram_total = 100 * rss / process.memory_percent()
        ram = {'free': ram_total - rss, 'used': rss, 'total': ram_total}
    except Exception as err:
        ram = {'error': f'{err}'}
    try:
        import torch
        if not torch.cuda.is_available():
            cuda = {'error': 'unavailable'}
        else:
            free, total = torch.cuda.mem_get_info()
            stats = dict(torch.cuda.memory_stats(shared.device))

            def pair(prefix):
                # current/peak byte counters from the torch allocator stats
                return {'current': stats[f'{prefix}.all.current'], 'peak': stats[f'{prefix}.all.peak']}

            cuda = {
                'system': {'free': free, 'used': total - free, 'total': total},
                'active': pair('active_bytes'),
                'allocated': pair('allocated_bytes'),
                'reserved': pair('reserved_bytes'),
                'inactive': pair('inactive_split_bytes'),
                'events': {'retries': stats['num_alloc_retries'], 'oom': stats['num_ooms']},
            }
    except Exception as err:
        cuda = {'error': f'{err}'}
    return models.MemoryResponse(ram=ram, cuda=cuda)
def get_extensions_list(self):
    """List installed extensions that have a git remote, with repo metadata."""
    from modules import extensions
    extensions.list_extensions()
    result = []
    for ext in extensions.extensions:
        ext: extensions.Extension
        ext.read_info_from_repo()
        # extensions without a remote (local-only) are deliberately omitted
        if ext.remote is None:
            continue
        result.append({
            "name": ext.name,
            "remote": ext.remote,
            "branch": ext.branch,
            "commit_hash": ext.commit_hash,
            "commit_date": ext.commit_date,
            "version": ext.version,
            "enabled": ext.enabled,
        })
    return result
def launch(self, server_name, port, root_path):
    """Attach the API router and serve the app with uvicorn (blocking call)."""
    self.app.include_router(self.router)
    options = {
        "host": server_name,
        "port": port,
        "timeout_keep_alive": shared.cmd_opts.timeout_keep_alive,
        "root_path": root_path,
        # TLS is enabled only when both paths are configured; None disables it
        "ssl_keyfile": shared.cmd_opts.tls_keyfile,
        "ssl_certfile": shared.cmd_opts.tls_certfile,
    }
    uvicorn.run(self.app, **options)
def kill_webui(self):
    """Terminate the whole server process."""
    restart.stop_program()
def restart_webui(self):
    """Restart the server when the launcher supports it; otherwise respond 501."""
    if restart.is_restartable():
        restart.restart_program()
    # reached only when restarting is unsupported (or the restart call returned)
    return Response(status_code=501)
def stop_webui(request):
    """Ask the main loop to shut the server down gracefully.

    NOTE(review): unlike sibling handlers this takes `request` instead of
    `self` — presumably it is registered directly as a route handler; confirm
    before changing the signature.
    """
    shared.state.server_command = "stop"
    return Response("Stopping.")