Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
automatic1111
GitHub Repository: automatic1111/stable-diffusion-webui
Path: blob/master/modules/images.py
3055 views
1
from __future__ import annotations
2
3
import datetime
4
import functools
5
import pytz
6
import io
7
import math
8
import os
9
from collections import namedtuple
10
import re
11
12
import numpy as np
13
import piexif
14
import piexif.helper
15
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin, ImageOps
16
# pillow_avif needs to be imported somewhere in code for it to work
17
import pillow_avif # noqa: F401
18
import string
19
import json
20
import hashlib
21
22
from modules import sd_samplers, shared, script_callbacks, errors
23
from modules.paths_internal import roboto_ttf_file
24
from modules.shared import opts
25
26
# Pillow >= 9.1 exposes resampling filters on the Image.Resampling enum; older releases only have
# the module-level constant, so fall back to looking the name up on the Image module itself.
LANCZOS = getattr(Image, 'Resampling', Image).LANCZOS
27
28
29
def get_font(fontsize: int):
    """Load the TrueType font used for grid annotations at ``fontsize``.

    Uses ``opts.font`` when it is set. If it is empty, or loading it raises, the bundled
    ``roboto_ttf_file`` is loaded instead.
    """
    try:
        preferred = opts.font or roboto_ttf_file
        font = ImageFont.truetype(preferred, fontsize)
    except Exception:
        # The configured font could not be used; fall back to the bundled one.
        font = ImageFont.truetype(roboto_ttf_file, fontsize)
    return font
34
35
36
def image_grid(imgs, batch_size=1, rows=None):
    """Arrange ``imgs`` into a grid and return it as a single RGB image.

    Args:
        imgs: Non-empty sequence of PIL images.
        batch_size: Number of rows used when ``opts.n_rows`` is 0.
        rows: Explicit number of rows. When None it is derived from ``opts.n_rows``:
            a positive value is used as-is, 0 means ``batch_size``, and a negative value means
            ``round(sqrt(len(imgs)))``. With ``opts.grid_prevent_empty_spots``, a negative value
            instead means the largest count <= ``floor(sqrt(len(imgs)))`` that divides ``len(imgs)``.
            The row count is capped at ``len(imgs)``.

    Extensions may change ``imgs``, ``cols`` and ``rows`` through
    ``script_callbacks.image_grid_callback``; the values on the callback params drive the layout.

    Returns:
        A ``cols * w`` by ``rows * h`` RGB image filled with ``opts.grid_background_color``, where
        ``w``/``h`` are the largest width/height among the pasted images. Smaller images are
        centred in their cell.

    Raises:
        ValueError: if ``imgs`` is empty.
    """
    if len(imgs) == 0:
        raise ValueError("image_grid requires at least one image")

    if rows is None:
        if opts.n_rows > 0:
            rows = opts.n_rows
        elif opts.n_rows == 0:
            rows = batch_size
        elif opts.grid_prevent_empty_spots:
            # Largest row count not above sqrt(n) that divides n evenly, so no cell stays empty.
            rows = math.floor(math.sqrt(len(imgs)))
            while len(imgs) % rows != 0:
                rows -= 1
        else:
            rows = math.sqrt(len(imgs))
            rows = round(rows)
    if rows > len(imgs):
        rows = len(imgs)

    cols = math.ceil(len(imgs) / rows)

    params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
    script_callbacks.image_grid_callback(params)

    # Size the cells from the images that are actually pasted: a callback may have replaced params.imgs.
    w, h = map(max, zip(*(img.size for img in params.imgs)))
    grid_background_color = ImageColor.getcolor(opts.grid_background_color, 'RGB')
    grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color=grid_background_color)

    for i, img in enumerate(params.imgs):
        img_w, img_h = img.size
        # Centre images that are smaller than the cell.
        w_offset, h_offset = 0 if img_w == w else (w - img_w) // 2, 0 if img_h == h else (h - img_h) // 2
        grid.paste(img, box=(i % params.cols * w + w_offset, i // params.cols * h + h_offset))

    return grid
67
68
69
class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])):
    """Tiling of an image, as built by ``split_grid`` and reassembled by ``combine_grid``.

    ``tiles`` is a list of rows. Each row is ``[y, tile_h, row_tiles]``, and ``row_tiles`` is a list
    of ``[x, tile_w, tile_image]`` entries.
    """

    @property
    def tile_count(self) -> int:
        """Number of tiles summed over every row of ``tiles``."""
        total = 0
        for row in self.tiles:
            total += len(row[2])
        return total
76
77
78
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
    """Split ``image`` into overlapping ``tile_w`` x ``tile_h`` tiles.

    Tiles are spread evenly: the first tile of each axis starts at 0 and the last one ends at the
    image edge. Neighbouring tiles overlap by roughly ``overlap`` pixels or more. Use
    ``combine_grid`` to reassemble the result.

    When the image is smaller than a tile, the tile origin becomes negative, and ``image.crop``
    extends the tile past the image bounds.

    Raises:
        ValueError: if ``overlap`` is not smaller than both tile dimensions, because the tiling
            could not advance across the image.
    """
    w, h = image.size

    non_overlap_width = tile_w - overlap
    non_overlap_height = tile_h - overlap
    if non_overlap_width <= 0 or non_overlap_height <= 0:
        raise ValueError(f"overlap ({overlap}) must be smaller than the tile size ({tile_w}x{tile_h})")

    # Always produce at least one tile per axis; for images no larger than the overlap the
    # formula would otherwise give 0 and the grid would contain no tiles at all.
    cols = max(1, math.ceil((w - overlap) / non_overlap_width))
    rows = max(1, math.ceil((h - overlap) / non_overlap_height))

    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images = []

        y = int(row * dy)

        # Snap the last row (and any row that would run past the edge) flush to the bottom.
        if y + tile_h >= h:
            y = h - tile_h

        for col in range(cols):
            x = int(col * dx)

            if x + tile_w >= w:
                x = w - tile_w

            tile = image.crop((x, y, x + tile_w, y + tile_h))

            row_images.append([x, tile_w, tile])

        grid.tiles.append([y, tile_h, row_images])

    return grid
112
113
114
def combine_grid(grid):
    """Reassemble a ``Grid`` (as produced by ``split_grid``) into a single RGB image.

    Tiles are pasted left to right within a row, and rows top to bottom. Wherever a tile, or a
    whole row, overlaps content that is already placed, its first ``grid.overlap`` pixels are
    pasted through a linear 0..255 mask. The seam therefore fades from the existing content into
    the new pixels.
    """
    overlap = grid.overlap
    ramp = np.arange(overlap, dtype=np.float32)

    def to_mask(values):
        # Scale 0..overlap-1 onto 0..<255 and wrap it as an 8-bit grayscale ("L") mask image.
        scaled = (values * 255 / overlap).astype(np.uint8)
        return Image.fromarray(scaled, 'L')

    # Left-to-right ramp for seams inside a row; top-to-bottom ramp for seams between rows.
    horizontal_mask = to_mask(ramp.reshape((1, overlap)).repeat(grid.tile_h, axis=0))
    vertical_mask = to_mask(ramp.reshape((overlap, 1)).repeat(grid.image_w, axis=1))

    result = Image.new("RGB", (grid.image_w, grid.image_h))
    for row_y, row_h, row_tiles in grid.tiles:
        strip = Image.new("RGB", (grid.image_w, row_h))
        for tile_x, tile_w, tile in row_tiles:
            if tile_x == 0:
                strip.paste(tile, (0, 0))
            else:
                # Blend the overlapping left edge, then copy the remainder of the tile unchanged.
                strip.paste(tile.crop((0, 0, overlap, row_h)), (tile_x, 0), mask=horizontal_mask)
                strip.paste(tile.crop((overlap, 0, tile_w, row_h)), (tile_x + overlap, 0))

        if row_y == 0:
            result.paste(strip, (0, 0))
        else:
            # Blend the overlapping top edge of the row, then copy the remainder of it unchanged.
            result.paste(strip.crop((0, 0, strip.width, overlap)), (0, row_y), mask=vertical_mask)
            result.paste(strip.crop((0, overlap, strip.width, row_h)), (0, row_y + overlap))

    return result
142
143
144
class GridAnnotation:
    """One label line drawn beside a grid by ``draw_grid_annotations``.

    Attributes:
        text: Label text.
        is_active: Active labels use the active text color. Inactive labels use the inactive
            color and are struck through.
        size: ``(width, height)`` of the rendered text. None until the layout code measures it.
        allowed_width: Horizontal space available to the label. None until the layout code assigns it.
    """

    def __init__(self, text='', is_active=True):
        self.text = text
        self.is_active = is_active
        self.size = None
        # Declared up front so every instance exposes the same attributes, even before layout.
        self.allowed_width = None
149
150
151
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
    """Add column labels above, and row labels to the left of, a grid image.

    Args:
        im: Grid image made of equally sized ``width`` x ``height`` cells.
        width: Cell width in pixels.
        height: Cell height in pixels.
        hor_texts: One list of ``GridAnnotation`` per column, drawn above that column.
        ver_texts: One list of ``GridAnnotation`` per row, drawn to the left of that row.
        margin: Gap in pixels inserted between neighbouring cells in the output.

    Note:
        The inner lists of ``hor_texts`` / ``ver_texts`` are rewritten in place. They end up
        holding the word-wrapped annotations, with ``size`` and ``allowed_width`` filled in.

    Returns:
        A new RGB image containing the re-spaced cells and their labels.
    """
    color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
    color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
    color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')

    def wrap(drawing, text, font, line_length):
        """Greedy word wrap of ``text`` into lines no wider than ``line_length`` where possible."""
        lines = ['']
        for word in text.split():
            line = f'{lines[-1]} {word}'.strip()
            # A word that is too wide on its own still goes onto the empty current line rather
            # than leaving a blank line in front of it; draw_texts shrinks the font for it later.
            if not lines[-1] or drawing.textlength(line, font=font) <= line_length:
                lines[-1] = line
            else:
                lines.append(word)
        return lines

    def text_width(drawing, text, font):
        """Rendered width of ``text``. Replaces ImageDraw.multiline_textsize, which Pillow 10 removed."""
        left, _top, right, _bottom = drawing.multiline_textbbox((0, 0), text, font=font)
        return right - left

    def stacked_height(lines):
        """Height of ``lines`` as laid out by draw_texts: the line heights plus the gaps between them."""
        if not lines:
            return 0
        return sum(line.size[1] for line in lines) + line_spacing * (len(lines) - 1)

    def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
        """Draw ``lines`` centred on ``draw_x``, starting at ``draw_y``. Inactive lines are struck through."""
        for line in lines:
            fnt = initial_fnt
            fontsize = initial_fontsize
            # Shrink the font until the line fits. Stop at size 1, since recent Pillow rejects size 0.
            while text_width(drawing, line.text, fnt) > line.allowed_width and fontsize > 1:
                fontsize -= 1
                fnt = get_font(fontsize)
            drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")

            if not line.is_active:
                drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)

            draw_y += line.size[1] + line_spacing

    fontsize = (width + height) // 25
    line_spacing = fontsize // 2

    fnt = get_font(fontsize)

    # Reserve a left column for row labels only when at least one row label has text.
    pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4

    cols = im.width // width
    rows = im.height // height

    assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
    assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'

    # Scratch canvas used only to measure text.
    calc_img = Image.new("RGB", (1, 1), color_background)
    calc_d = ImageDraw.Draw(calc_img)

    for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
        items = [] + texts
        texts.clear()

        for line in items:
            wrapped = wrap(calc_d, line.text, fnt, allowed_width)
            texts += [GridAnnotation(x, line.is_active) for x in wrapped]

        for line in texts:
            bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
            line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
            line.allowed_width = allowed_width

    # Both axes use the same layout rule as draw_texts, so labels are centred consistently.
    hor_text_heights = [stacked_height(lines) for lines in hor_texts]
    ver_text_heights = [stacked_height(lines) for lines in ver_texts]

    pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2

    result = Image.new("RGB", (im.width + pad_left + margin * (cols - 1), im.height + pad_top + margin * (rows - 1)), color_background)

    for row in range(rows):
        for col in range(cols):
            cell = im.crop((width * col, height * row, width * (col + 1), height * (row + 1)))
            result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))

    d = ImageDraw.Draw(result)

    for col in range(cols):
        x = pad_left + (width + margin) * col + width / 2
        y = pad_top / 2 - hor_text_heights[col] / 2

        draw_texts(d, x, y, hor_texts[col], fnt, fontsize)

    for row in range(rows):
        x = pad_left / 2
        y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2

        draw_texts(d, x, y, ver_texts[row], fnt, fontsize)

    return result
237
238
239
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
    """Label a prompt-matrix grid image.

    ``all_prompts[0]`` is skipped. The remaining prompt parts are split so that the first half
    (rounded up) labels the columns and the rest labels the rows. In the cell at index ``combo``
    along an axis, part ``bit`` is shown as active when bit ``bit`` of ``combo`` is set.
    """
    parts = all_prompts[1:]
    split_at = math.ceil(len(parts) / 2)
    horizontal_parts, vertical_parts = parts[:split_at], parts[split_at:]

    def annotations_for(axis_parts):
        # One annotation list per bit combination of this axis' parts.
        return [
            [GridAnnotation(text, is_active=(combo & (1 << bit)) != 0) for bit, text in enumerate(axis_parts)]
            for combo in range(1 << len(axis_parts))
        ]

    hor_texts = annotations_for(horizontal_parts)
    ver_texts = annotations_for(vertical_parts)
    return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
250
251
252
def resize_image(resize_mode, im, width, height, upscaler_name=None):
    """
    Resizes an image with the specified resize_mode, width, and height.

    Args:
        resize_mode: The mode to use when resizing the image.
            0: Stretch the image to exactly width x height.
            1: Scale the image, keeping its aspect ratio, until it covers width x height, then
               centre it and crop whatever extends past the target.
            2 (and any other value): Scale the image, keeping its aspect ratio, to fit inside
               width x height, then centre it. The empty bands on either side are filled with
               pixels stretched from the image's edges.
        im: The image to resize.
        width: The width to resize the image to.
        height: The height to resize the image to.
        upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.

    Enlarging goes through the named upscaler, unless the name is None/"None" or the image is in
    'L' mode. The result is then LANCZOS-resized to the exact requested size.

    Returns:
        The resized image. Modes 1 and 2 always return a new RGB image.
    """
    upscaler_name = upscaler_name or opts.upscaler_for_img2img

    def resize(im, w, h):
        """Resize ``im`` to ``w`` x ``h``, running the upscaler first when enlarging."""
        plain_resize = upscaler_name is None or upscaler_name == "None" or im.mode == 'L'
        if plain_resize:
            return im.resize((w, h), resample=LANCZOS)

        scale = max(w / im.width, h / im.height)

        if scale > 1.0:
            matching = [x for x in shared.sd_upscalers if x.name == upscaler_name]
            if matching:
                upscaler = matching[0]
            else:
                upscaler = shared.sd_upscalers[0]
                print(f"could not find upscaler named {upscaler_name or '<empty string>'}, using {upscaler.name} as a fallback")

            im = upscaler.scaler.upscale(im, scale, upscaler.data_path)

        if im.width != w or im.height != h:
            im = im.resize((w, h), resample=LANCZOS)

        return im

    if resize_mode == 0:
        return resize(im, width, height)

    ratio = width / height
    src_ratio = im.width / im.height

    if resize_mode == 1:
        # Cover: the scaled image is at least as large as the target on both axes.
        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width
    else:
        # Fit: the scaled image is at most as large as the target on both axes.
        src_w = width if ratio < src_ratio else im.width * height // im.height
        src_h = height if ratio >= src_ratio else im.height * width // im.width

    resized = resize(im, src_w, src_h)
    res = Image.new("RGB", (width, height))
    res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

    if resize_mode == 1:
        return res

    if ratio < src_ratio:
        # Letterboxed vertically: stretch the top and bottom pixel rows into the empty bands.
        fill_height = height // 2 - src_h // 2
        if fill_height > 0:
            res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
            res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
    elif ratio > src_ratio:
        # Pillarboxed horizontally: stretch the left and right pixel columns into the empty bands.
        fill_width = width // 2 - src_w // 2
        if fill_width > 0:
            res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
            res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))

    return res
327
328
329
# Characters that sanitize_filename_part replaces with '_'. The default is a conservative set of
# characters that are problematic in file names on common platforms. When
# shared.cmd_opts.unix_filenames_sanitization is set, only the path separator '/' is replaced.
invalid_filename_chars = '/' if shared.cmd_opts.unix_filenames_sanitization else '#<>:"/\\|?*\n\r\t'
# Characters stripped from the start of each sanitized name part.
invalid_filename_prefix = ' '
# Characters stripped from the end of each sanitized name part.
invalid_filename_postfix = ' .'
335
# Word separators for [prompt_words]: runs of whitespace and/or ASCII punctuation. The punctuation
# is escaped so every character is matched literally. Unescaped, the "\]" inside string.punctuation
# acted as an escape sequence and silently dropped the backslash from the character class.
re_nonletters = re.compile(r'[\s' + re.escape(string.punctuation) + ']+')
# Splits a filename pattern into (literal text, tag) pairs; the tag group is None for trailing text.
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
# Peels one trailing "<argument>" off a tag, e.g. "datetime<%Y>" -> ("datetime", "%Y").
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
338
# Maximum number of characters kept from each sanitized filename part (taken from shared.cmd_opts).
max_filename_part_length = shared.cmd_opts.filenames_max_length
# Sentinel a filename replacement can return to emit nothing and also drop the literal text
# written before the tag. Compared by identity, so it must stay a unique object.
NOTHING_AND_SKIP_PREVIOUS_TEXT: object = object()
340
341
342
def sanitize_filename_part(text, replace_spaces=True):
    """Make ``text`` usable as one component of a file or directory name.

    The steps, in order:

    1. When ``replace_spaces`` is true, spaces become ``_``.
    2. Every character in ``invalid_filename_chars`` becomes ``_``.
    3. Leading ``invalid_filename_prefix`` characters are stripped.
    4. The result is cut to ``max_filename_part_length`` characters.
    5. Trailing ``invalid_filename_postfix`` characters are stripped.

    ``None`` is returned unchanged.
    """
    if text is None:
        return None

    if replace_spaces:
        text = text.replace(' ', '_')

    table = str.maketrans(dict.fromkeys(invalid_filename_chars, '_'))
    cleaned = text.translate(table).lstrip(invalid_filename_prefix)
    return cleaned[:max_filename_part_length].rstrip(invalid_filename_postfix)
353
354
355
@functools.cache
def get_scheduler_str(sampler_name, scheduler_name):
    """Return the capitalized scheduler name to use with ``sampler_name``.

    ``'Automatic'`` is resolved through the sampler's config: its ``'scheduler'`` option is used,
    and the name stays ``'Automatic'`` when that option is absent. Results are memoized per
    argument pair.
    """
    resolved = scheduler_name
    if resolved == 'Automatic':
        sampler_config = sd_samplers.find_sampler_config(sampler_name)
        resolved = sampler_config.options.get('scheduler', 'Automatic')
    return resolved.capitalize()
362
363
364
@functools.cache
def get_sampler_scheduler_str(sampler_name, scheduler_name):
    """Return ``'<sampler_name> <Scheduler>'``; the scheduler part comes from get_scheduler_str. Memoized."""
    scheduler_part = get_scheduler_str(sampler_name, scheduler_name)
    return f'{sampler_name} {scheduler_part}'
368
369
370
def get_sampler_scheduler(p, sampler):
    """Return the sanitized '{Sampler} {Scheduler}' (when ``sampler`` is truthy) or '{Scheduler}' text for ``p``.

    Returns NOTHING_AND_SKIP_PREVIOUS_TEXT when ``p`` lacks a ``scheduler`` or ``sampler_name`` attribute.
    """
    if not (hasattr(p, 'scheduler') and hasattr(p, 'sampler_name')):
        return NOTHING_AND_SKIP_PREVIOUS_TEXT

    formatter = get_sampler_scheduler_str if sampler else get_scheduler_str
    return sanitize_filename_part(formatter(p.sampler_name, p.scheduler), replace_spaces=False)
379
380
381
class FilenameGenerator:
    """Expands ``[tag]`` placeholders in filename and directory patterns for one saved image.

    Each key of ``replacements`` is a tag name; ``apply`` matches tags case-insensitively. Each
    value is called with the generator plus any ``<argument>`` values written after the tag, for
    example ``[datetime<%Y-%m-%d><UTC>]``.

    A replacement may return:
    - ``NOTHING_AND_SKIP_PREVIOUS_TEXT``, to drop the tag together with the literal text in front of it;
    - ``None``, to leave the literal ``[tag]`` in the output.
    """
    replacements = {
        'basename': lambda self: self.basename or 'img',
        'seed': lambda self: self.seed if self.seed is not None else '',
        'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
        'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
        'steps': lambda self: self.p and self.p.steps,
        'cfg': lambda self: self.p and self.p.cfg_scale,
        'width': lambda self: self.image.width,
        'height': lambda self: self.image.height,
        'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
        'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
        'sampler_scheduler': lambda self: self.p and get_sampler_scheduler(self.p, True),
        'scheduler': lambda self: self.p and get_sampler_scheduler(self.p, False),
        # The shared fallbacks are only evaluated when p lacks the attribute; an eager getattr
        # default would fail whenever no model is loaded, even if p carries the value.
        'model_hash': lambda self: self.p.sd_model_hash if hasattr(self.p, "sd_model_hash") else shared.sd_model.sd_model_hash,
        'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False),
        'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
        'datetime': lambda self, *args: self.datetime(*args),  # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
        'job_timestamp': lambda self: self.p.job_timestamp if hasattr(self.p, "job_timestamp") else shared.state.job_timestamp,
        'prompt_hash': lambda self, *args: self.string_hash(self.prompt, *args),
        'negative_prompt_hash': lambda self, *args: self.string_hash(self.p.negative_prompt, *args),
        'full_prompt_hash': lambda self, *args: self.string_hash(f"{self.p.prompt} {self.p.negative_prompt}", *args),  # a space in between to create a unique string
        'prompt': lambda self: sanitize_filename_part(self.prompt),
        'prompt_no_styles': lambda self: self.prompt_no_style(),
        'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
        'prompt_words': lambda self: self.prompt_words(),
        'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
        'batch_size': lambda self: self.p.batch_size,
        'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
        'hasprompt': lambda self, *args: self.hasprompt(*args),  # accepts formats:[hasprompt<prompt1|default><prompt2>..]
        'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
        'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
        'user': lambda self: self.p.user,
        'vae_filename': lambda self: self.get_vae_filename(),
        'none': lambda self: '',  # Overrides the default, so you can get just the sequence number
        'image_hash': lambda self, *args: self.image_hash(*args)  # accepts formats: [image_hash<length>] default full hash
    }
    default_time_format = '%Y%m%d%H%M%S'

    def __init__(self, p, seed, prompt, image, zip=False, basename=""):
        self.p = p
        self.seed = seed
        self.prompt = prompt
        self.image = image
        self.zip = zip
        self.basename = basename

    def get_vae_filename(self):
        """Get the name of the loaded VAE file without its extension, or "NoneType" if no VAE is loaded."""

        import modules.sd_vae as sd_vae

        if sd_vae.loaded_vae_file is None:
            return "NoneType"

        file_name = os.path.basename(sd_vae.loaded_vae_file)
        split_file_name = file_name.split('.')
        if len(split_file_name) > 1 and split_file_name[0] == '':
            return split_file_name[1]  # if the first character of the filename is "." then [1] is obtained.
        else:
            return split_file_name[0]

    def hasprompt(self, *args):
        """Expand ``[hasprompt<text|default>...]``.

        For each argument, ``text`` (lower-cased) is appended when it occurs in the lower-cased
        prompt; otherwise the optional ``default`` is appended. Returns None when there is no
        processing object or no prompt.
        """
        # Check for missing data before touching self.prompt, which may be None.
        if self.p is None or self.prompt is None:
            return None
        lower = self.prompt.lower()
        outres = ""
        for arg in args:
            if arg != "":
                division = arg.split("|")
                expected = division[0].lower()
                default = division[1] if len(division) > 1 else ""
                if expected in lower:
                    outres += expected
                else:
                    outres += default
        return sanitize_filename_part(outres)

    def prompt_no_style(self):
        """Return the prompt with the text of the applied styles removed, sanitized without replacing spaces."""
        if self.p is None or self.prompt is None:
            return None

        prompt_no_style = self.prompt
        for style in shared.prompt_styles.get_style_prompts(self.p.styles):
            if style:
                for part in style.split("{prompt}"):
                    prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')

                prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()

        return sanitize_filename_part(prompt_no_style, replace_spaces=False)

    def prompt_words(self):
        """Return the first ``opts.directories_max_prompt_words`` words of the prompt, or "empty"."""
        words = [x for x in re_nonletters.split(self.prompt or "") if x]
        if len(words) == 0:
            words = ["empty"]
        return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)

    def datetime(self, *args):
        """Format the current time.

        ``args[0]`` is an optional strftime format and ``args[1]`` an optional pytz zone name.
        Unknown zones fall back to local time; invalid formats fall back to ``default_time_format``.
        """
        time_datetime = datetime.datetime.now()

        time_format = args[0] if (args and args[0] != "") else self.default_time_format
        try:
            time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
        except pytz.exceptions.UnknownTimeZoneError:
            time_zone = None

        time_zone_time = time_datetime.astimezone(time_zone)
        try:
            formatted_time = time_zone_time.strftime(time_format)
        except (ValueError, TypeError):
            formatted_time = time_zone_time.strftime(self.default_time_format)

        return sanitize_filename_part(formatted_time, replace_spaces=False)

    def image_hash(self, *args):
        """SHA-256 hex digest of the raw image bytes, truncated to ``args[0]`` characters (full length by default)."""
        length = int(args[0]) if (args and args[0] != "") else None
        return hashlib.sha256(self.image.tobytes()).hexdigest()[0:length]

    def string_hash(self, text, *args):
        """SHA-256 hex digest of ``text`` (UTF-8), truncated to ``args[0]`` characters (8 by default)."""
        length = int(args[0]) if (args and args[0] != "") else 8
        return hashlib.sha256(text.encode()).hexdigest()[0:length]

    def apply(self, x):
        """Expand every ``[tag]`` / ``[tag<arg>...]`` in pattern ``x`` and return the result.

        Unknown tags, and tags whose replacement returns None, are kept as literal ``[tag]``
        text; their arguments are dropped. An exception raised by a replacement is reported
        and then treated as None.
        """
        res = ''

        for m in re_pattern.finditer(x):
            text, pattern = m.groups()

            if pattern is None:
                res += text
                continue

            # Peel "<arg>" suffixes off the end of the tag, keeping them in their written order.
            pattern_args = []
            while True:
                arg_match = re_pattern_arg.match(pattern)
                if arg_match is None:
                    break

                pattern, arg = arg_match.groups()
                pattern_args.insert(0, arg)

            fun = self.replacements.get(pattern.lower())
            if fun is not None:
                try:
                    replacement = fun(self, *pattern_args)
                except Exception:
                    replacement = None
                    errors.report(f"Error adding [{pattern}] to filename", exc_info=True)

                # Identity check: the sentinel is a unique object, and `==` would invoke the
                # replacement value's own __eq__ (which can even raise, e.g. for arrays).
                if replacement is NOTHING_AND_SKIP_PREVIOUS_TEXT:
                    continue
                elif replacement is not None:
                    res += text + str(replacement)
                    continue

            res += f'{text}[{pattern}]'

        return res
541
542
543
def get_next_sequence_number(path, basename):
    """Return the next sequence number to use for an image saved into directory ``path``.

    Only entries whose names start with ``"<basename>-"`` are considered; when ``basename`` is
    empty, every entry is. After removing that prefix and the extension, the text before the
    first ``-`` is parsed as an integer. The result is the largest such number plus one, or 0
    when no entry yields a number.
    """
    prefix = f"{basename}-" if basename != '' else basename
    highest = -1
    for entry in os.listdir(path):
        if not entry.startswith(prefix):
            continue
        stem = os.path.splitext(entry[len(prefix):])[0]
        first_field = stem.split('-')[0]
        try:
            highest = max(int(first_field), highest)
        except ValueError:
            # Not a numbered file; ignore it.
            pass
    return highest + 1
563
564
565
def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
    """
    Saves image to filename, including geninfo as text information for generation info.

    The output format comes from ``extension``, given with its dot (e.g. ``".png"``) and matched
    case-insensitively. It defaults to the extension of ``filename``.

    - PNG: when ``opts.enable_pnginfo`` is on, geninfo is stored under ``pnginfo_section_name`` in
      ``existing_pnginfo`` (a non-empty dict passed in is modified in place). Every entry is then
      written as a PNG text chunk.
    - JPEG / WebP: RGBA images are converted to RGB, and 'I;16' images are scaled to 8 bits first.
      When enabled, geninfo is inserted afterwards into the EXIF UserComment tag.
    - AVIF: when enabled, geninfo is passed to the encoder in the EXIF UserComment tag.
    - GIF: geninfo is stored as the image comment.
    - Anything else is saved without metadata.

    Raises:
        KeyError: if Pillow has no format registered for the extension.
    """

    def exif_user_comment(text):
        # EXIF block carrying the infotext in the UserComment tag, where read_info_from_image looks for it.
        return piexif.dump({
            "Exif": {
                piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(text or "", encoding="unicode")
            },
        })

    if extension is None:
        extension = os.path.splitext(filename)[1]

    # Pillow's extension registry is keyed by lower-case extensions, so "image.PNG" must be normalised.
    ext = extension.lower()
    image_format = Image.registered_extensions()[ext]

    if ext == '.png':
        existing_pnginfo = existing_pnginfo or {}
        pnginfo_data = None
        if opts.enable_pnginfo:
            existing_pnginfo[pnginfo_section_name] = geninfo
            pnginfo_data = PngImagePlugin.PngInfo()
            for k, v in existing_pnginfo.items():
                pnginfo_data.add_text(k, str(v))

        image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)

    elif ext in (".jpg", ".jpeg", ".webp"):
        if image.mode == 'RGBA':
            image = image.convert("RGB")
        elif image.mode == 'I;16':
            # 255 / 65535: map 16-bit samples onto the 8-bit range.
            image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if ext == ".webp" else "L")

        image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)

        if opts.enable_pnginfo and geninfo is not None:
            piexif.insert(exif_user_comment(geninfo), filename)
    elif ext == '.avif':
        exif_bytes = exif_user_comment(geninfo) if opts.enable_pnginfo and geninfo is not None else None
        image.save(filename, format=image_format, quality=opts.jpeg_quality, exif=exif_bytes)
    elif ext == ".gif":
        image.save(filename, format=image_format, comment=geninfo)
    else:
        image.save(filename, format=image_format, quality=opts.jpeg_quality)
622
623
624
def _fit_name_to_bytes(name, max_bytes):
    """Shorten ``name`` so its UTF-8 encoding fits in ``max_bytes`` bytes, never splitting a character."""
    encoded = name.encode('utf-8', errors='surrogatepass')
    if len(encoded) <= max_bytes:
        return name
    return encoded[:max(max_bytes, 0)].decode('utf-8', errors='ignore')


def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
    """Save an image.

    Args:
        image (`PIL.Image`):
            The image to be saved.
        path (`str`):
            The directory to save the image in. When saving to sub-directories is enabled (see
            `save_to_dirs`), the image goes into a sub-directory named by
            `opts.directories_filename_pattern`.
        basename (`str`):
            Prefix of the file name; also available as the `[basename]` pattern tag.
        seed, prompt:
            Values for the `[seed]` / `[prompt...]` filename pattern tags.
        extension (`str`):
            Image file extension without the dot, default `png`. Oversized JPEG/WebP images are saved as PNG.
        info (`str`):
            Generation parameters, stored under `pnginfo_section_name` in the image metadata and,
            when `opts.save_txt` is on, written to a `.txt` file next to the image.
        short_filename (`bool`):
            If true (or `seed` is None), the filename pattern is not used.
        no_prompt (`bool`):
            For non-grid images, turns off the `opts.save_to_dirs` default for `save_to_dirs`.
        grid (`bool`):
            Use `opts.grid_save_to_dirs` instead of `opts.save_to_dirs` for the `save_to_dirs` default.
        pnginfo_section_name (`str`):
            Metadata key under which `info` is stored.
        p (`StableDiffusionProcessing`):
            Processing object used by the filename pattern tags.
        existing_info (`dict`):
            Additional metadata to store. `info` is added to it under `pnginfo_section_name`.
        forced_filename (`str`):
            If specified, used as the file name (without extension); `basename` and the filename pattern are ignored.
        suffix (`str`):
            Appended to the pattern-generated part of the file name.
        save_to_dirs (bool):
            If true, the image will be saved into a subdirectory of `path`. When None, it is derived from `opts`.

    Returns: (fullfn, txt_fullfn)
        fullfn (`str`):
            The full path of the file that was actually written.
        txt_fullfn (`str` or None):
            If a text file is saved for this image, this will be its full path. Otherwise None.
    """
    namegen = FilenameGenerator(p, seed, prompt, image, basename=basename)

    # WebP is limited to 16383 px per side and JPEG to 65535; switch to PNG, which allows much larger images.
    if (image.height > 65535 or image.width > 65535) and extension.lower() in ("jpg", "jpeg") or (image.height > 16383 or image.width > 16383) and extension.lower() == "webp":
        print('Image dimensions too large; saving as PNG')
        extension = "png"

    if save_to_dirs is None:
        save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)

    if save_to_dirs:
        dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
        path = os.path.join(path, dirname)

    os.makedirs(path, exist_ok=True)

    if forced_filename is None:
        if short_filename or seed is None:
            file_decoration = ""
        elif opts.save_to_dirs:
            file_decoration = opts.samples_filename_pattern or "[seed]"
        else:
            file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"

        file_decoration = namegen.apply(file_decoration) + suffix

        add_number = opts.save_images_add_number or file_decoration == ''

        if file_decoration != "" and add_number:
            file_decoration = f"-{file_decoration}"

        if add_number:
            basecount = get_next_sequence_number(path, basename)
            fullfn = None
            # Probe forward from the next sequence number for a free name (bounded to 500 tries).
            for i in range(500):
                fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
                fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
                if not os.path.exists(fullfn):
                    break
        else:
            fullfn = os.path.join(path, f"{file_decoration}.{extension}")
    else:
        fullfn = os.path.join(path, f"{forced_filename}.{extension}")

    pnginfo = existing_info or {}
    if info is not None:
        pnginfo[pnginfo_section_name] = info

    params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
    script_callbacks.before_image_saved_callback(params)

    image = params.image
    fullfn = params.filename
    info = params.pnginfo.get(pnginfo_section_name, None)

    def _atomically_save_image(image_to_save, filename_without_extension, extension):
        """
        Save the image under a .tmp name first and then rename it into place, to avoid a race
        condition when another process detects the new image in the directory.

        Returns the path that was finally written. Unless the replace action is "Replace", an
        existing file is kept and a "-N" variant of the name is used instead.
        """
        temp_file_path = f"{filename_without_extension}.tmp"

        save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)

        filename = filename_without_extension + extension
        if shared.opts.save_images_replace_action != "Replace":
            n = 0
            while os.path.exists(filename):
                n += 1
                filename = f"{filename_without_extension}-{n}{extension}"
        os.replace(temp_file_path, filename)
        return filename

    fullfn_without_extension, extension = os.path.splitext(params.filename)
    if hasattr(os, 'statvfs'):
        max_name_len = os.statvfs(path).f_namemax
        # f_namemax limits a single path component, counted in bytes, not the whole path. Shorten
        # only the file name, and leave room for the longest extension written next to it (.tmp/.txt/.jpg).
        directory, name = os.path.split(fullfn_without_extension)
        short_name = _fit_name_to_bytes(name, max_name_len - max(4, len(extension)))
        if short_name != name:
            fullfn_without_extension = os.path.join(directory, short_name)
        params.filename = fullfn_without_extension + extension
        fullfn = params.filename

    # Record the path actually written; it differs from the planned one if a "-N" variant was used.
    fullfn = _atomically_save_image(image, fullfn_without_extension, extension)
    params.filename = fullfn
    fullfn_without_extension = os.path.splitext(fullfn)[0]

    image.already_saved_as = fullfn

    oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
    if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
        ratio = image.width / image.height
        resize_to = None
        if oversize and ratio > 1:
            resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
        elif oversize:
            resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)

        if resize_to is not None:
            try:
                # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
                image = image.resize(resize_to, LANCZOS)
            except Exception:
                image = image.resize(resize_to)
        try:
            _atomically_save_image(image, fullfn_without_extension, ".jpg")
        except Exception as e:
            errors.display(e, "saving image as downscaled JPG")

    if opts.save_txt and info is not None:
        txt_fullfn = f"{fullfn_without_extension}.txt"
        with open(txt_fullfn, "w", encoding="utf8") as file:
            file.write(f"{info}\n")
    else:
        txt_fullfn = None

    script_callbacks.image_saved_callback(params)

    return fullfn, txt_fullfn
return fullfn, txt_fullfn
768
769
770
# Image.info entries that describe encoding/container details rather than generation data;
# read_info_from_image removes them from the extra items it returns.
IGNORED_INFO_KEYS = set(
    'jfif jfif_version jfif_unit jfif_density dpi exif '
    'loop background timestamp duration progressive progression '
    'icc_profile chromaticity photoshop'.split()
)
775
776
777
def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
    """Extract the generation infotext and the remaining metadata from ``image.info``.

    Sources are checked in this order:

    1. The PNG ``parameters`` entry, used as the starting value.
    2. A non-empty EXIF UserComment, which replaces it (JPEG/WebP/AVIF).
    3. If there is no EXIF block at all, a ``comment`` entry (GIF), which replaces it.
    4. Images whose ``Software`` is "NovelAI" are converted to the webui infotext format.

    Returns:
        ``(geninfo, items)``: the infotext (or None), and a copy of ``image.info`` without
        ``parameters`` and without the keys listed in ``IGNORED_INFO_KEYS``.
    """
    items = (image.info or {}).copy()

    geninfo = items.pop('parameters', None)

    if "exif" in items:
        exif_data = items["exif"]
        try:
            exif = piexif.load(exif_data)
        except (OSError, ValueError):
            # memory / exif was not valid so piexif tried to read from a file; malformed bytes then
            # surface as OSError (no such file) or ValueError (e.g. "embedded null byte").
            exif = None
        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
        try:
            exif_comment = piexif.helper.UserComment.load(exif_comment)
        except ValueError:
            exif_comment = exif_comment.decode('utf8', errors="ignore")

        if exif_comment:
            geninfo = exif_comment
    elif "comment" in items:  # for gif
        if isinstance(items["comment"], bytes):
            geninfo = items["comment"].decode('utf8', errors="ignore")
        else:
            geninfo = items["comment"]

    for field in IGNORED_INFO_KEYS:
        items.pop(field, None)

    if items.get("Software", None) == "NovelAI":
        try:
            json_info = json.loads(items["Comment"])
            sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")

            geninfo = f"""{items["Description"]}
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
        except Exception:
            errors.report("Error parsing NovelAI image generation parameters", exc_info=True)

    return geninfo, items
818
819
820
def image_data(data):
    """Interpret uploaded bytes as an image with embedded infotext, or else as plain text.

    Returns:
        ``(infotext, None)`` when ``data`` opens as an image, or when it is UTF-8 text shorter
        than 10000 characters. Otherwise returns ``(gr.update(), None)``, a no-op UI update.
    """
    try:
        image = read(io.BytesIO(data))
        textinfo, _ = read_info_from_image(image)
        return textinfo, None
    except Exception:
        # Not an image that can be opened (or its metadata is unreadable): try plain text instead.
        pass

    try:
        text = data.decode('utf8')
    except Exception:
        text = None

    # An explicit length check rather than `assert`, which is stripped under `python -O`.
    if text is not None and len(text) < 10000:
        return text, None

    import gradio as gr  # imported only when the no-op UI update is actually needed

    return gr.update(), None
839
840
841
def flatten(img, bgcolor):
    """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency

    Handled as transparent: RGBA images, LA images, and palette (P) images that carry a
    ``transparency`` entry. Their transparent pixels are composited over ``bgcolor``; every
    other mode is simply converted to RGB.
    """
    # Bring the other alpha-carrying modes to RGBA first. A plain convert('RGB') would discard
    # their transparency instead of compositing it over bgcolor.
    if img.mode == "LA" or (img.mode == "P" and "transparency" in img.info):
        img = img.convert("RGBA")

    if img.mode == "RGBA":
        background = Image.new('RGBA', img.size, bgcolor)
        background.paste(img, mask=img)
        img = background

    return img.convert('RGB')
850
851
852
def read(fp, **kwargs):
    """Open an image via ``Image.open(fp, **kwargs)`` and pass it through ``fix_image``."""
    return fix_image(Image.open(fp, **kwargs))
857
858
859
def fix_image(image: Image.Image):
    """Apply the EXIF orientation and then the PNG transparency fix to ``image``.

    Best effort: if a step raises, the image as processed up to that point is returned.
    ``None`` is passed through unchanged.
    """
    if image is not None:
        try:
            for fix in (ImageOps.exif_transpose, fix_png_transparency):
                image = fix(image)
        except Exception:
            # Deliberately ignored: a failed fix should not prevent the image from loading.
            pass
    return image
870
871
872
def fix_png_transparency(image: Image.Image):
    """Convert RGB or P images whose ``transparency`` info is a bytes value to RGBA; return all others unchanged."""
    needs_alpha = image.mode in ("RGB", "P") and isinstance(image.info.get("transparency"), bytes)
    return image.convert("RGBA") if needs_alpha else image
878
879