# NOTE: the commented-out section below is an alternative (TorchScript, separable
# kernel, MS-SSIM capable) implementation kept for reference only. It is not executed.
#
# '''
# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
# '''
#
# import torch
# import torch.jit
# import torch.nn.functional as F
#
#
# @torch.jit.script
# def create_window(window_size: int, sigma: float, channel: int):
#     '''
#     Create 1-D gauss kernel
#     :param window_size: the size of gauss kernel
#     :param sigma: sigma of normal distribution
#     :param channel: input channel
#     :return: 1D kernel
#     '''
#     coords = torch.arange(window_size, dtype=torch.float)
#     coords -= window_size // 2
#
#     g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
#     g /= g.sum()
#
#     g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
#     return g
#
#
# @torch.jit.script
# def _gaussian_filter(x, window_1d, use_padding: bool):
#     '''
#     Blur input with 1-D kernel
#     :param x: batch of tensors to be blurred
#     :param window_1d: 1-D gauss kernel
#     :param use_padding: padding image before conv
#     :return: blurred tensors
#     '''
#     C = x.shape[1]
#     padding = 0
#     if use_padding:
#         window_size = window_1d.shape[3]
#         padding = window_size // 2
#     out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
#     out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
#     return out
#
#
# @torch.jit.script
# def ssim(X, Y, window, data_range: float, use_padding: bool = False):
#     '''
#     Calculate ssim index for X and Y
#     :param X: images [B, C, H, N_bins]
#     :param Y: images [B, C, H, N_bins]
#     :param window: 1-D gauss kernel
#     :param data_range: value range of input images. (usually 1.0 or 255)
#     :param use_padding: padding image before conv
#     :return:
#     '''
#
#     K1 = 0.01
#     K2 = 0.03
#     compensation = 1.0
#
#     C1 = (K1 * data_range) ** 2
#     C2 = (K2 * data_range) ** 2
#
#     mu1 = _gaussian_filter(X, window, use_padding)
#     mu2 = _gaussian_filter(Y, window, use_padding)
#     sigma1_sq = _gaussian_filter(X * X, window, use_padding)
#     sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
#     sigma12 = _gaussian_filter(X * Y, window, use_padding)
#
#     mu1_sq = mu1.pow(2)
#     mu2_sq = mu2.pow(2)
#     mu1_mu2 = mu1 * mu2
#
#     sigma1_sq = compensation * (sigma1_sq - mu1_sq)
#     sigma2_sq = compensation * (sigma2_sq - mu2_sq)
#     sigma12 = compensation * (sigma12 - mu1_mu2)
#
#     cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
#     # Fixed the issue that the negative value of cs_map caused ms_ssim to output Nan.
#     cs_map = cs_map.clamp_min(0.)
#     ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
#
#     ssim_val = ssim_map.mean(dim=(1, 2, 3))  # reduce along CHW
#     cs = cs_map.mean(dim=(1, 2, 3))
#
#     return ssim_val, cs
#
#
# @torch.jit.script
# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
#     '''
#     interface of ms-ssim
#     :param X: a batch of images, (N,C,H,W)
#     :param Y: a batch of images, (N,C,H,W)
#     :param window: 1-D gauss kernel
#     :param data_range: value range of input images. (usually 1.0 or 255)
#     :param weights: weights for different levels
#     :param use_padding: padding image before conv
#     :param eps: use for avoid grad nan.
#     :return:
#     '''
#     levels = weights.shape[0]
#     cs_vals = []
#     ssim_vals = []
#     for _ in range(levels):
#         ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
#         # Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf.
#         ssim_val = ssim_val.clamp_min(eps)
#         cs = cs.clamp_min(eps)
#         cs_vals.append(cs)
#
#         ssim_vals.append(ssim_val)
#         padding = (X.shape[2] % 2, X.shape[3] % 2)
#         X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
#         Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
#
#     cs_vals = torch.stack(cs_vals, dim=0)
#     ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
#     return ms_ssim_val
#
#
# class SSIM(torch.jit.ScriptModule):
#     __constants__ = ['data_range', 'use_padding']
#
#     def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
#         '''
#         :param window_size: the size of gauss kernel
#         :param window_sigma: sigma of normal distribution
#         :param data_range: value range of input images. (usually 1.0 or 255)
#         :param channel: input channels (default: 3)
#         :param use_padding: padding image before conv
#         '''
#         super().__init__()
#         assert window_size % 2 == 1, 'Window size must be odd.'
#         window = create_window(window_size, window_sigma, channel)
#         self.register_buffer('window', window)
#         self.data_range = data_range
#         self.use_padding = use_padding
#
#     @torch.jit.script_method
#     def forward(self, X, Y):
#         r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
#         return r[0]
#
#
# class MS_SSIM(torch.jit.ScriptModule):
#     __constants__ = ['data_range', 'use_padding', 'eps']
#
#     def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
#                  levels=None, eps=1e-8):
#         '''
#         class for ms-ssim
#         :param window_size: the size of gauss kernel
#         :param window_sigma: sigma of normal distribution
#         :param data_range: value range of input images. (usually 1.0 or 255)
#         :param channel: input channels
#         :param use_padding: padding image before conv
#         :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
#         :param levels: number of downsampling
#         :param eps: Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf.
#         '''
#         super().__init__()
#         assert window_size % 2 == 1, 'Window size must be odd.'
#         self.data_range = data_range
#         self.use_padding = use_padding
#         self.eps = eps
#
#         window = create_window(window_size, window_sigma, channel)
#         self.register_buffer('window', window)
#
#         if weights is None:
#             weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
#         weights = torch.tensor(weights, dtype=torch.float)
#
#         if levels is not None:
#             weights = weights[:levels]
#             weights = weights / weights.sum()
#
#         self.register_buffer('weights', weights)
#
#     @torch.jit.script_method
#     def forward(self, X, Y):
#         return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
#                        use_padding=self.use_padding, eps=self.eps)
#
#
# if __name__ == '__main__':
#     print('Simple Test')
#     im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
#     img1 = im / 255
#     img2 = img1 * 0.5
#
#     losser = SSIM(data_range=1.).cuda()
#     loss = losser(img1, img2).mean()
#
#     losser2 = MS_SSIM(data_range=1.).cuda()
#     loss2 = losser2(img1, img2).mean()
#
#     print(loss.item())
#     print(loss2.item())
#
# if __name__ == '__main__':
#     print('Training Test')
#     import cv2
#     import torch.optim
#     import numpy as np
#     import imageio
#     import time
#
#     out_test_video = False
#     # Better not to write a GIF directly, it gets very large; write an mkv file first
#     # and convert it to GIF with ffmpeg afterwards.
#     video_use_gif = False
#
#     im = cv2.imread('test_img1.jpg', 1)
#     t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
#
#     if out_test_video:
#         if video_use_gif:
#             fps = 0.5
#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
#             suffix = '.gif'
#         else:
#             fps = 5
#             out_wh = (im.shape[1], im.shape[0])
#             suffix = '.mkv'
#         video_last_time = time.perf_counter()
#         video = imageio.get_writer('ssim_test' + suffix, fps=fps)
#
#     # Test ssim
#     print('Training SSIM')
#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
#     rand_im.requires_grad = True
#     optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
#     losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
#     ssim_score = 0
#     while ssim_score < 0.999:
#         optim.zero_grad()
#         loss = losser(rand_im, t_im)
#         (-loss).sum().backward()
#         ssim_score = loss.item()
#         optim.step()
#         r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
#         r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
#
#         if out_test_video:
#             if time.perf_counter() - video_last_time > 1. / fps:
#                 video_last_time = time.perf_counter()
#                 out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
#                 out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
#                 if isinstance(out_frame, cv2.UMat):
#                     out_frame = out_frame.get()
#                 video.append_data(out_frame)
#
#         cv2.imshow('ssim', r_im)
#         cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
#         cv2.waitKey(1)
#
#     if out_test_video:
#         video.close()
#
#     # Test ms_ssim
#     if out_test_video:
#         if video_use_gif:
#             fps = 0.5
#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
#             suffix = '.gif'
#         else:
#             fps = 5
#             out_wh = (im.shape[1], im.shape[0])
#             suffix = '.mkv'
#         video_last_time = time.perf_counter()
#         video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
#
#     print('Training MS_SSIM')
#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
#     rand_im.requires_grad = True
#     optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
#     losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
#     ssim_score = 0
#     while ssim_score < 0.999:
#         optim.zero_grad()
#         loss = losser(rand_im, t_im)
#         (-loss).sum().backward()
#         ssim_score = loss.item()
#         optim.step()
#         r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
#         r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
#
#         if out_test_video:
#             if time.perf_counter() - video_last_time > 1. / fps:
#                 video_last_time = time.perf_counter()
#                 out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
#                 out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
#                 if isinstance(out_frame, cv2.UMat):
#                     out_frame = out_frame.get()
#                 video.append_data(out_frame)
#
#         cv2.imshow('ms_ssim', r_im)
#         cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
#         cv2.waitKey(1)
#
#     if out_test_video:
#         video.close()

"""
Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
"""

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp


def gaussian(window_size, sigma):
    """Build a normalized 1-D Gaussian kernel.

    :param window_size: number of taps in the kernel
    :param sigma: standard deviation of the Gaussian
    :return: 1-D tensor of length ``window_size`` whose entries sum to 1
    """
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    gauss = torch.Tensor([exp(-(x - center) ** 2 / denom) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    """Build the 2-D Gaussian window (sigma fixed at 1.5) used as a depthwise conv kernel.

    :param window_size: side length of the square window
    :param channel: number of image channels; the kernel is replicated once per channel
    :return: float tensor of shape ``(channel, 1, window_size, window_size)``
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    # Outer product of the 1-D kernel with itself gives the separable 2-D kernel.
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _window_for(cached, img, window_size, channel):
    """Return ``cached`` when it fits ``img``; otherwise build a matching window.

    A cached window is reused only if its shape corresponds to the requested
    ``channel``/``window_size`` and its dtype and device equal those of ``img``.
    """
    expected_shape = (channel, 1, window_size, window_size)
    if (
        cached is not None
        and tuple(cached.shape) == expected_shape
        and cached.dtype == img.dtype
        and cached.device == img.device
    ):
        return cached
    return create_window(window_size, channel).to(device=img.device, dtype=img.dtype)


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Compute SSIM between two image batches with a precomputed window.

    The stabilizing constants are ``0.01 ** 2`` and ``0.03 ** 2``.

    :param img1: 4-D tensor; dimension 1 is the channel dimension
    :param img2: tensor of the same shape as ``img1``
    :param window: depthwise kernel as returned by :func:`create_window`
    :param window_size: side length of ``window``; padding is ``window_size // 2``
    :param channel: number of channels (used as ``groups`` for the convolution)
    :param size_average: if true return the mean over all elements, otherwise
        the mean over dimension 1 only (one map per batch item)
    :return: scalar tensor, or the channel-averaged SSIM map
    """
    pad = window_size // 2

    def blur(t):
        # Depthwise Gaussian filtering: each channel is convolved independently.
        return F.conv2d(t, window, padding=pad, groups=channel)

    mu1 = blur(img1)
    mu2 = blur(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[xy] - E[x]E[y].
    sigma1_sq = blur(img1 * img1) - mu1_sq
    sigma2_sq = blur(img2 * img2) - mu2_sq
    sigma12 = blur(img1 * img2) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    return ssim_map.mean(1)


class SSIM(torch.nn.Module):
    """Module wrapper around :func:`_ssim` that caches its Gaussian window."""

    def __init__(self, window_size=11, size_average=True):
        """
        :param window_size: side length of the Gaussian window
        :param size_average: forwarded to :func:`_ssim`
        """
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Start with a single-channel window; it is rebuilt on demand in forward().
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM of ``img1`` and ``img2`` (4-D tensors of equal shape)."""
        (_, channel, _, _) = img1.size()

        # Reuse the cached window only when channel count, window size, dtype and
        # device all match the input; otherwise rebuild and remember the new one.
        window = _window_for(self.window, img1, self.window_size, channel)
        self.window = window
        self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


# Module-level cache of the window used by the most recent ssim() call.
window = None


def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM between two 4-D image tensors of equal shape.

    The Gaussian window is cached in the module-level ``window`` and rebuilt
    whenever the channel count, ``window_size``, dtype or device of the input
    no longer matches the cached tensor.

    :param img1: 4-D tensor; dimension 1 is the channel dimension
    :param img2: tensor of the same shape as ``img1``
    :param window_size: side length of the Gaussian window
    :param size_average: forwarded to :func:`_ssim`
    :return: result of :func:`_ssim`
    """
    (_, channel, _, _) = img1.size()
    global window
    window = _window_for(window, img1, window_size, channel)
    return _ssim(img1, img2, window, window_size, channel, size_average)