Path: blob/master/modules/hat_model.py
3055 views
import os
import sys

from modules import modelloader, devices
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model


class UpscalerHAT(Upscaler):
    """Upscaler backend that serves HAT model checkpoints.

    Every ``.pt`` / ``.pth`` file reported by ``find_models`` is registered as
    one selectable scaler. The checkpoint itself is only loaded when
    ``do_upscale`` runs.
    """

    def __init__(self, dirname):
        # Set these before the base initializer runs; it may read them.
        # NOTE(review): confirm against modules.upscaler.Upscaler.__init__.
        self.name = "HAT"
        self.scalers = []
        self.user_path = dirname
        super().__init__()

        # TODO: the real scale factor is only known after loading a model,
        # so every discovered checkpoint is advertised as 4x.
        assumed_scale = 4
        self.scalers.extend(
            UpscalerData(
                modelloader.friendly_name(model_file),
                model_file,
                upscaler=self,
                scale=assumed_scale,
            )
            for model_file in self.find_models(ext_filter=[".pt", ".pth"])
        )

    def do_upscale(self, img, selected_model):
        """Upscale ``img`` with the HAT checkpoint at ``selected_model``.

        This is best-effort. If loading the checkpoint raises any exception,
        a message is written to stderr and ``img`` is returned unchanged.
        Otherwise the result of ``upscale_with_model`` is returned.
        """
        try:
            hat_model = self.load_model(selected_model)
        except Exception as exc:
            print(f"Unable to load HAT model {selected_model}: {exc}", file=sys.stderr)
            return img
        else:
            # TODO: should probably be a dedicated device_hat
            hat_model.to(devices.device_esrgan)
            # TODO: should probably be HAT_tile / HAT_tile_overlap options
            tile_size = opts.ESRGAN_tile
            tile_overlap = opts.ESRGAN_tile_overlap
            return upscale_with_model(hat_model, img, tile_size=tile_size, tile_overlap=tile_overlap)

    def load_model(self, path: str):
        """Load the HAT checkpoint stored at ``path``.

        Delegates to ``modelloader.load_spandrel_model``. It requests the
        ``'HAT'`` architecture on ``devices.device_esrgan`` and returns
        whatever that helper returns.

        Raises:
            FileNotFoundError: if ``path`` does not name an existing file.
        """
        if os.path.isfile(path):
            return modelloader.load_spandrel_model(
                path,
                device=devices.device_esrgan,  # TODO: should probably be device_hat
                expected_architecture='HAT',
            )
        raise FileNotFoundError(f"Model file {path} not found")