# Path: blob/master/modules/face_restoration_utils.py
from __future__ import annotations12import logging3import os4from functools import cached_property5from typing import TYPE_CHECKING, Callable67import cv28import numpy as np9import torch1011from modules import devices, errors, face_restoration, shared1213if TYPE_CHECKING:14from facexlib.utils.face_restoration_helper import FaceRestoreHelper1516logger = logging.getLogger(__name__)171819def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:20"""Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""21assert img.shape[2] == 3, "image must be RGB"22if img.dtype == "float64":23img = img.astype("float32")24img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)25return torch.from_numpy(img.transpose(2, 0, 1)).float()262728def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:29"""30Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.31"""32tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)33tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])34assert tensor.dim() == 3, "tensor must be RGB"35img_np = tensor.numpy().transpose(1, 2, 0)36if img_np.shape[2] == 1: # gray image, no RGB/BGR required37return np.squeeze(img_np, axis=2)38return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)394041def create_face_helper(device) -> FaceRestoreHelper:42from facexlib.detection import retinaface43from facexlib.utils.face_restoration_helper import FaceRestoreHelper44if hasattr(retinaface, 'device'):45retinaface.device = device46return FaceRestoreHelper(47upscale_factor=1,48face_size=512,49crop_ratio=(1, 1),50det_model='retinaface_resnet50',51save_ext='png',52use_parse=True,53device=device,54)555657def restore_with_face_helper(58np_image: np.ndarray,59face_helper: FaceRestoreHelper,60restore_face: Callable[[torch.Tensor], torch.Tensor],61) -> np.ndarray:62"""63Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.6465`restore_face` should 
take a cropped face image and return a restored face image.66"""67from torchvision.transforms.functional import normalize68np_image = np_image[:, :, ::-1]69original_resolution = np_image.shape[0:2]7071try:72logger.debug("Detecting faces...")73face_helper.clean_all()74face_helper.read_image(np_image)75face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)76face_helper.align_warp_face()77logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))78for cropped_face in face_helper.cropped_faces:79cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)80normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)81cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)8283try:84with torch.no_grad():85cropped_face_t = restore_face(cropped_face_t)86devices.torch_gc()87except Exception:88errors.report('Failed face-restoration inference', exc_info=True)8990restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))91restored_face = (restored_face * 255.0).astype('uint8')92face_helper.add_restored_face(restored_face)9394logger.debug("Merging restored faces into image")95face_helper.get_inverse_affine(None)96img = face_helper.paste_faces_to_input_image()97img = img[:, :, ::-1]98if original_resolution != img.shape[0:2]:99img = cv2.resize(100img,101(0, 0),102fx=original_resolution[1] / img.shape[1],103fy=original_resolution[0] / img.shape[0],104interpolation=cv2.INTER_LINEAR,105)106logger.debug("Face restoration complete")107finally:108face_helper.clean_all()109return img110111112class CommonFaceRestoration(face_restoration.FaceRestoration):113net: torch.Module | None114model_url: str115model_download_name: str116117def __init__(self, model_path: str):118super().__init__()119self.net = None120self.model_path = model_path121os.makedirs(model_path, exist_ok=True)122123@cached_property124def face_helper(self) -> FaceRestoreHelper:125return 
create_face_helper(self.get_device())126127def send_model_to(self, device):128if self.net:129logger.debug("Sending %s to %s", self.net, device)130self.net.to(device)131if self.face_helper:132logger.debug("Sending face helper to %s", device)133self.face_helper.face_det.to(device)134self.face_helper.face_parse.to(device)135136def get_device(self):137raise NotImplementedError("get_device must be implemented by subclasses")138139def load_net(self) -> torch.Module:140raise NotImplementedError("load_net must be implemented by subclasses")141142def restore_with_helper(143self,144np_image: np.ndarray,145restore_face: Callable[[torch.Tensor], torch.Tensor],146) -> np.ndarray:147try:148if self.net is None:149self.net = self.load_net()150except Exception:151logger.warning("Unable to load face-restoration model", exc_info=True)152return np_image153154try:155self.send_model_to(self.get_device())156return restore_with_face_helper(np_image, self.face_helper, restore_face)157finally:158if shared.opts.face_restoration_unload:159self.send_model_to(devices.cpu)160161162def patch_facexlib(dirname: str) -> None:163import facexlib.detection164import facexlib.parsing165166det_facex_load_file_from_url = facexlib.detection.load_file_from_url167par_facex_load_file_from_url = facexlib.parsing.load_file_from_url168169def update_kwargs(kwargs):170return dict(kwargs, save_dir=dirname, model_dir=None)171172def facex_load_file_from_url(**kwargs):173return det_facex_load_file_from_url(**update_kwargs(kwargs))174175def facex_load_file_from_url2(**kwargs):176return par_facex_load_file_from_url(**update_kwargs(kwargs))177178facexlib.detection.load_file_from_url = facex_load_file_from_url179facexlib.parsing.load_file_from_url = facex_load_file_from_url2180181182