Path: blob/master/modules/deepbooru.py
3028 views
"""DeepDanbooru interrogator: produce a comma-separated tag string for an image."""

import os
import re

import torch
import numpy as np

from modules import modelloader, paths, deepbooru_model, devices, images, shared

# Matches the characters that receive a backslash prefix when tag escaping is
# enabled: a backslash and both parentheses.
re_special = re.compile(r'([\\()])')


class DeepDanbooru:
    """Lazily loaded wrapper around ``deepbooru_model.DeepDanbooruModel``.

    ``self.model`` is ``None`` until :meth:`load` has run.
    """

    def __init__(self):
        self.model = None

    def load(self):
        """Create the model and load its weights; no-op if already loaded.

        The checkpoint path comes from ``modelloader.load_models``, which is
        given the ``torch_deepdanbooru`` directory under ``paths.models_path``
        plus a URL and download name (presumably used to fetch the file when
        no local copy exists -- confirm against ``modelloader``).
        """
        if self.model is not None:
            return

        files = modelloader.load_models(
            model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
            model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
            ext_filter=[".pt"],
            download_name='model-resnet_custom_v3.pt',
        )

        self.model = deepbooru_model.DeepDanbooruModel()
        # NOTE(review): torch.load unpickles the checkpoint, so it must only
        # be pointed at trusted files. Confirm whether the project restricts
        # unpickling elsewhere before changing this call.
        self.model.load_state_dict(torch.load(files[0], map_location="cpu"))

        self.model.eval()
        # Park the freshly loaded model on the CPU in the configured dtype;
        # start() moves it to the compute device on demand.
        self.model.to(devices.cpu, devices.dtype)

    def start(self):
        """Ensure the model is loaded and move it to ``devices.device``."""
        self.load()
        self.model.to(devices.device)

    def stop(self):
        """Move the model back to the CPU and run ``devices.torch_gc()``.

        Does nothing when ``shared.opts.interrogate_keep_models_in_memory``
        is set.
        """
        if not shared.opts.interrogate_keep_models_in_memory:
            self.model.to(devices.cpu)
            devices.torch_gc()

    def tag(self, pil_image):
        """Tag a single image, handling model start-up and release.

        Returns the same string as :meth:`tag_multi`.
        """
        self.start()
        try:
            return self.tag_multi(pil_image)
        finally:
            # Release the device even when inference raises, so a failed
            # request does not leave the model resident on the device.
            self.stop()

    def tag_multi(self, pil_image, force_disable_ranks=False):
        """Run the model on ``pil_image`` and format the predicted tags.

        This method does not call :meth:`start` or :meth:`stop`; the caller
        is responsible for having the model loaded and on ``devices.device``.

        Args:
            pil_image: image object providing ``convert("RGB")``
                (presumably a ``PIL.Image.Image``).
            force_disable_ranks: when true, never append probabilities to the
                tags, regardless of ``shared.opts.interrogate_return_ranks``.

        Returns:
            The selected tags joined with ``", "``.
        """
        threshold = shared.opts.interrogate_deepbooru_score_threshold
        use_spaces = shared.opts.deepbooru_use_spaces
        use_escape = shared.opts.deepbooru_escape
        alpha_sort = shared.opts.deepbooru_sort_alpha
        include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks

        # Resize to 512x512 using resize mode 2 (mode semantics are defined
        # in modules.images), then scale pixel values by 1/255 and add a
        # leading axis of length 1.
        pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
        a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255

        with torch.no_grad(), devices.autocast():
            x = torch.from_numpy(a).to(devices.device, devices.dtype)
            y = self.model(x)[0].detach().cpu().numpy()

        probability_dict = {}

        for tag, probability in zip(self.model.tags, y):
            if probability < threshold:
                continue

            # Rating pseudo-tags are never reported.
            if tag.startswith("rating:"):
                continue

            probability_dict[tag] = probability

        if alpha_sort:
            tags = sorted(probability_dict)
        else:
            # Highest probability first.
            tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda item: -item[1])]

        res = []

        # User-configured tags to drop. Spaces are normalised to underscores
        # because the comparison below is against the unformatted tag names.
        filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}

        for tag in tags:
            if tag in filtertags:
                continue

            probability = probability_dict[tag]
            tag_outformat = tag
            if use_spaces:
                tag_outformat = tag_outformat.replace('_', ' ')
            if use_escape:
                # Prefix backslashes and parentheses with a backslash.
                tag_outformat = re_special.sub(r'\\\1', tag_outformat)
            if include_ranks:
                tag_outformat = f"({tag_outformat}:{probability:.3f})"

            res.append(tag_outformat)

        return ", ".join(res)


model = DeepDanbooru()