Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
prophesier
GitHub Repository: prophesier/diff-svc
Path: blob/main/modules/commons/ssim.py
694 views
1
# '''
2
# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
3
# '''
4
#
5
# import torch
6
# import torch.jit
7
# import torch.nn.functional as F
8
#
9
#
10
# @torch.jit.script
11
# def create_window(window_size: int, sigma: float, channel: int):
12
# '''
13
# Create 1-D gauss kernel
14
# :param window_size: the size of gauss kernel
15
# :param sigma: sigma of normal distribution
16
# :param channel: input channel
17
# :return: 1D kernel
18
# '''
19
# coords = torch.arange(window_size, dtype=torch.float)
20
# coords -= window_size // 2
21
#
22
# g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
23
# g /= g.sum()
24
#
25
# g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
26
# return g
27
#
28
#
29
# @torch.jit.script
30
# def _gaussian_filter(x, window_1d, use_padding: bool):
31
# '''
32
# Blur input with 1-D kernel
33
# :param x: batch of tensors to be blurred
34
# :param window_1d: 1-D gauss kernel
35
# :param use_padding: padding image before conv
36
# :return: blurred tensors
37
# '''
38
# C = x.shape[1]
39
# padding = 0
40
# if use_padding:
41
# window_size = window_1d.shape[3]
42
# padding = window_size // 2
43
# out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
44
# out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
45
# return out
46
#
47
#
48
# @torch.jit.script
49
# def ssim(X, Y, window, data_range: float, use_padding: bool = False):
50
# '''
51
# Calculate ssim index for X and Y
52
# :param X: images [B, C, H, N_bins]
53
# :param Y: images [B, C, H, N_bins]
54
# :param window: 1-D gauss kernel
55
# :param data_range: value range of input images. (usually 1.0 or 255)
56
# :param use_padding: padding image before conv
57
# :return:
58
# '''
59
#
60
# K1 = 0.01
61
# K2 = 0.03
62
# compensation = 1.0
63
#
64
# C1 = (K1 * data_range) ** 2
65
# C2 = (K2 * data_range) ** 2
66
#
67
# mu1 = _gaussian_filter(X, window, use_padding)
68
# mu2 = _gaussian_filter(Y, window, use_padding)
69
# sigma1_sq = _gaussian_filter(X * X, window, use_padding)
70
# sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
71
# sigma12 = _gaussian_filter(X * Y, window, use_padding)
72
#
73
# mu1_sq = mu1.pow(2)
74
# mu2_sq = mu2.pow(2)
75
# mu1_mu2 = mu1 * mu2
76
#
77
# sigma1_sq = compensation * (sigma1_sq - mu1_sq)
78
# sigma2_sq = compensation * (sigma2_sq - mu2_sq)
79
# sigma12 = compensation * (sigma12 - mu1_mu2)
80
#
81
# cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
82
# # Fixed the issue that the negative value of cs_map caused ms_ssim to output Nan.
83
# cs_map = cs_map.clamp_min(0.)
84
# ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
85
#
86
# ssim_val = ssim_map.mean(dim=(1, 2, 3)) # reduce along CHW
87
# cs = cs_map.mean(dim=(1, 2, 3))
88
#
89
# return ssim_val, cs
90
#
91
#
92
# @torch.jit.script
93
# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
94
# '''
95
# interface of ms-ssim
96
# :param X: a batch of images, (N,C,H,W)
97
# :param Y: a batch of images, (N,C,H,W)
98
# :param window: 1-D gauss kernel
99
# :param data_range: value range of input images. (usually 1.0 or 255)
100
# :param weights: weights for different levels
101
# :param use_padding: padding image before conv
102
# :param eps: lower clamp used to avoid NaN gradients.
103
# :return:
104
# '''
105
# levels = weights.shape[0]
106
# cs_vals = []
107
# ssim_vals = []
108
# for _ in range(levels):
109
# ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
110
# # Works around an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf.
111
# ssim_val = ssim_val.clamp_min(eps)
112
# cs = cs.clamp_min(eps)
113
# cs_vals.append(cs)
114
#
115
# ssim_vals.append(ssim_val)
116
# padding = (X.shape[2] % 2, X.shape[3] % 2)
117
# X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
118
# Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
119
#
120
# cs_vals = torch.stack(cs_vals, dim=0)
121
# ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
122
# return ms_ssim_val
123
#
124
#
125
# class SSIM(torch.jit.ScriptModule):
126
# __constants__ = ['data_range', 'use_padding']
127
#
128
# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
129
# '''
130
# :param window_size: the size of gauss kernel
131
# :param window_sigma: sigma of normal distribution
132
# :param data_range: value range of input images. (usually 1.0 or 255)
133
# :param channel: input channels (default: 3)
134
# :param use_padding: padding image before conv
135
# '''
136
# super().__init__()
137
# assert window_size % 2 == 1, 'Window size must be odd.'
138
# window = create_window(window_size, window_sigma, channel)
139
# self.register_buffer('window', window)
140
# self.data_range = data_range
141
# self.use_padding = use_padding
142
#
143
# @torch.jit.script_method
144
# def forward(self, X, Y):
145
# r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
146
# return r[0]
147
#
148
#
149
# class MS_SSIM(torch.jit.ScriptModule):
150
# __constants__ = ['data_range', 'use_padding', 'eps']
151
#
152
# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
153
# levels=None, eps=1e-8):
154
# '''
155
# class for ms-ssim
156
# :param window_size: the size of gauss kernel
157
# :param window_sigma: sigma of normal distribution
158
# :param data_range: value range of input images. (usually 1.0 or 255)
159
# :param channel: input channels
160
# :param use_padding: padding image before conv
161
# :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
162
# :param levels: number of downsampling
163
# :param eps: works around an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf.
164
# '''
165
# super().__init__()
166
# assert window_size % 2 == 1, 'Window size must be odd.'
167
# self.data_range = data_range
168
# self.use_padding = use_padding
169
# self.eps = eps
170
#
171
# window = create_window(window_size, window_sigma, channel)
172
# self.register_buffer('window', window)
173
#
174
# if weights is None:
175
# weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
176
# weights = torch.tensor(weights, dtype=torch.float)
177
#
178
# if levels is not None:
179
# weights = weights[:levels]
180
# weights = weights / weights.sum()
181
#
182
# self.register_buffer('weights', weights)
183
#
184
# @torch.jit.script_method
185
# def forward(self, X, Y):
186
# return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
187
# use_padding=self.use_padding, eps=self.eps)
188
#
189
#
190
# if __name__ == '__main__':
191
# print('Simple Test')
192
# im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
193
# img1 = im / 255
194
# img2 = img1 * 0.5
195
#
196
# losser = SSIM(data_range=1.).cuda()
197
# loss = losser(img1, img2).mean()
198
#
199
# losser2 = MS_SSIM(data_range=1.).cuda()
200
# loss2 = losser2(img1, img2).mean()
201
#
202
# print(loss.item())
203
# print(loss2.item())
204
#
205
# if __name__ == '__main__':
206
# print('Training Test')
207
# import cv2
208
# import torch.optim
209
# import numpy as np
210
# import imageio
211
# import time
212
#
213
# out_test_video = False
214
# # Avoid writing a GIF directly, since it will be very large; write an mkv file first and convert it to GIF with ffmpeg
215
# video_use_gif = False
216
#
217
# im = cv2.imread('test_img1.jpg', 1)
218
# t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
219
#
220
# if out_test_video:
221
# if video_use_gif:
222
# fps = 0.5
223
# out_wh = (im.shape[1] // 2, im.shape[0] // 2)
224
# suffix = '.gif'
225
# else:
226
# fps = 5
227
# out_wh = (im.shape[1], im.shape[0])
228
# suffix = '.mkv'
229
# video_last_time = time.perf_counter()
230
# video = imageio.get_writer('ssim_test' + suffix, fps=fps)
231
#
232
# # Test ssim
233
# print('Training SSIM')
234
# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
235
# rand_im.requires_grad = True
236
# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
237
# losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
238
# ssim_score = 0
239
# while ssim_score < 0.999:
240
# optim.zero_grad()
241
# loss = losser(rand_im, t_im)
242
# (-loss).sum().backward()
243
# ssim_score = loss.item()
244
# optim.step()
245
# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
246
# r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
247
#
248
# if out_test_video:
249
# if time.perf_counter() - video_last_time > 1. / fps:
250
# video_last_time = time.perf_counter()
251
# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
252
# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
253
# if isinstance(out_frame, cv2.UMat):
254
# out_frame = out_frame.get()
255
# video.append_data(out_frame)
256
#
257
# cv2.imshow('ssim', r_im)
258
# cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
259
# cv2.waitKey(1)
260
#
261
# if out_test_video:
262
# video.close()
263
#
264
# # Test ms_ssim
265
# if out_test_video:
266
# if video_use_gif:
267
# fps = 0.5
268
# out_wh = (im.shape[1] // 2, im.shape[0] // 2)
269
# suffix = '.gif'
270
# else:
271
# fps = 5
272
# out_wh = (im.shape[1], im.shape[0])
273
# suffix = '.mkv'
274
# video_last_time = time.perf_counter()
275
# video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
276
#
277
# print('Training MS_SSIM')
278
# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
279
# rand_im.requires_grad = True
280
# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
281
# losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
282
# ssim_score = 0
283
# while ssim_score < 0.999:
284
# optim.zero_grad()
285
# loss = losser(rand_im, t_im)
286
# (-loss).sum().backward()
287
# ssim_score = loss.item()
288
# optim.step()
289
# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
290
# r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
291
#
292
# if out_test_video:
293
# if time.perf_counter() - video_last_time > 1. / fps:
294
# video_last_time = time.perf_counter()
295
# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
296
# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
297
# if isinstance(out_frame, cv2.UMat):
298
# out_frame = out_frame.get()
299
# video.append_data(out_frame)
300
#
301
# cv2.imshow('ms_ssim', r_im)
302
# cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
303
# cv2.waitKey(1)
304
#
305
# if out_test_video:
306
# video.close()
307
308
"""
Structural similarity (SSIM) utilities for PyTorch tensors.

Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
"""
311
312
import torch
313
import torch.nn.functional as F
314
from torch.autograd import Variable
315
import numpy as np
316
from math import exp
317
318
319
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian kernel.

    :param window_size: number of taps in the kernel
    :param sigma: standard deviation of the Gaussian
    :return: 1-D tensor of length ``window_size`` whose entries sum to 1
    """
    center = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)

    taps = []
    for idx in range(window_size):
        offset = idx - center
        # Unnormalized Gaussian weight for this tap.
        taps.append(exp(-offset ** 2 / two_sigma_sq))

    kernel = torch.Tensor(taps)
    # Normalize so the kernel acts as a weighted average.
    return kernel / kernel.sum()
322
323
324
def create_window(window_size, channel, sigma=1.5):
    """Build the 2-D Gaussian kernel used as the grouped-convolution weight.

    :param window_size: side length of the square kernel
    :param channel: number of channels; the kernel is replicated once per channel
    :param sigma: standard deviation of the Gaussian (defaults to 1.5, the
        value that used to be hard-coded here)
    :return: float32 tensor of shape ``(channel, 1, window_size, window_size)``
    """
    _1D_window = gaussian(window_size, sigma).unsqueeze(1)
    # Outer product of the 1-D kernel with itself gives the separable 2-D kernel.
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # ``Variable`` is a deprecated no-op wrapper in current PyTorch, so the
    # tensor is returned directly.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()
329
330
331
def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Compute the SSIM map between ``img1`` and ``img2`` and reduce it.

    :param img1: first input batch, passed to ``F.conv2d``
    :param img2: second input batch, same layout as ``img1``
    :param window: convolution weight used to take local weighted averages
    :param window_size: kernel side length; ``window_size // 2`` is used as padding
    :param channel: number of groups for the grouped convolution
    :param size_average: if true, return the mean over every element of the
        SSIM map; otherwise average over dimension 1 only
    :return: scalar tensor when ``size_average`` is true, else the SSIM map
        averaged over dimension 1
    """
    pad = window_size // 2

    def local_mean(t):
        # Window-weighted local average, one group per channel.
        return F.conv2d(t, window, padding=pad, groups=channel)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[xy] - E[x]E[y].
    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    # Stabilizing constants; these equal (K * L) ** 2 with K1=0.01, K2=0.03
    # and a data range L of 1.
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator

    return ssim_map.mean() if size_average else ssim_map.mean(1)
352
353
354
class SSIM(torch.nn.Module):
    """Module wrapper around ``_ssim`` that caches its Gaussian window.

    The window is stored as a plain attribute, not a registered buffer, so
    moving the module with ``.to()``/``.cuda()`` does not move it. ``forward``
    therefore rebuilds the window whenever the cached one does not match the
    input's channel count, device or dtype.

    :param window_size: side length of the Gaussian window
    :param size_average: forwarded to ``_ssim``; if true, the result is the
        mean over the whole SSIM map
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Start with a single-channel window; replaced lazily in forward().
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM between ``img1`` and ``img2``.

        :param img1: 4-D input tensor; dimension 1 is the channel count
        :param img2: tensor with the same layout as ``img1``
        :return: result of ``_ssim`` for the two inputs
        """
        (_, channel, _, _) = img1.size()

        cached = self.window
        # Compare device and dtype explicitly. The legacy ``.type()`` string
        # does not include the CUDA device index, so a window cached on one
        # GPU would wrongly be reused for input on another.
        reusable = (
            channel == self.channel
            and cached.device == img1.device
            and cached.dtype == img1.dtype
        )

        if reusable:
            window = cached
        else:
            window = create_window(self.window_size, channel)
            window = window.to(device=img1.device, dtype=img1.dtype)
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
378
379
380
# Module-level cache of the most recently built Gaussian window used by ssim().
window = None


def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM between ``img1`` and ``img2``.

    The Gaussian window is cached in the module-level ``window`` global. It is
    rebuilt when the cached kernel does not match the requested channel count
    or window size. Previously the first window built was reused for every
    later call, even when those no longer matched.

    :param img1: 4-D input tensor; dimension 1 is the channel count
    :param img2: tensor with the same layout as ``img1``
    :param window_size: side length of the Gaussian window
    :param size_average: forwarded to ``_ssim``; if true, the result is the
        mean over the whole SSIM map
    :return: result of ``_ssim`` for the two inputs
    """
    (_, channel, _, _) = img1.size()
    global window
    expected_shape = (channel, 1, window_size, window_size)
    if window is None or tuple(window.shape) != expected_shape:
        window = create_window(window_size, channel)
    # Keep the cached window on the input's device and dtype
    # (``.to`` returns the same tensor when it already matches).
    window = window.to(device=img1.device, dtype=img1.dtype)
    return _ssim(img1, img2, window, window_size, channel, size_average)
392
393