Spaces:
Running
on
T4
Running
on
T4
| # -*- coding: utf-8 -*- | |
| import os | |
| import torch | |
| from torch import nn as nn | |
| import torch.nn.functional as F | |
class PixelLoss(nn.Module):
    """Plain pixel-wise L1 loss between a generated and a reference tensor.

    The result is the mean absolute difference over every element of the batch
    (``nn.L1Loss`` with its default ``reduction='mean'``).
    """

    def __init__(self) -> None:
        super().__init__()
        # nn.L1Loss owns no parameters or buffers, so there is nothing to move to a
        # device: the loss is computed on whatever device the inputs live on.
        self.criterion = nn.L1Loss()

    def forward(self, gen_hr, org_hr, batch_idx=None):
        """Return the mean L1 distance between ``gen_hr`` and ``org_hr``.

        Args:
            gen_hr (Tensor): generated output.
            org_hr (Tensor): reference target; expected to match ``gen_hr``'s shape.
            batch_idx (int, optional): unused; accepted so every loss in this file
                can be called with the same ``(pred, target, batch_idx)`` signature.

        Returns:
            Tensor: scalar mean absolute error.
        """
        return self.criterion(gen_hr, org_hr)
class L1_Charbonnier_loss(nn.Module):
    """L1 Charbonnier loss: ``mean(sqrt((X - Y) ** 2 + eps))``.

    A smooth, everywhere-differentiable variant of the L1 loss. ``eps`` is added to
    the *squared* difference, so it plays the role of ``epsilon ** 2`` in the
    common ``sqrt(d ** 2 + epsilon ** 2)`` formulation (the default ``1e-6``
    corresponds to ``epsilon = 1e-3``).

    Args:
        eps (float): constant added under the square root; defaults to ``1e-6``.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, X, Y, batch_idx=None):
        """Return the mean Charbonnier penalty between ``X`` and ``Y``.

        Args:
            X (Tensor): prediction.
            Y (Tensor): target, broadcastable against ``X``.
            batch_idx (int, optional): unused; kept for a uniform loss signature.

        Returns:
            Tensor: scalar loss.
        """
        diff = X - Y
        error = torch.sqrt(diff * diff + self.eps)
        return torch.mean(error)
| """ | |
| Created on Thu Dec 3 00:28:15 2020 | |
| @author: Yunpeng Li, Tianjin University | |
| """ | |
class MS_SSIM_L1_LOSS(nn.Module):
    """Mixed MS-SSIM + Gaussian-weighted L1 loss.

    Computes, per pixel::

        compensation * (alpha * (1 - MS-SSIM) + (1 - alpha) * G(|x - y|) / data_range)

    and returns its mean. Every "scale" j of MS-SSIM is a Gaussian window with
    standard deviation ``gaussian_sigmas[j]`` applied at full resolution (no
    downsampling). The contrast-structure term is multiplied over all scales and
    channels. The luminance term and the L1 weighting window G use only the last
    sigma, presumably the largest one.

    Args:
        alpha (float): weight of the MS-SSIM term; the L1 term gets ``1 - alpha``.
        gaussian_sigmas (sequence of float): per-scale Gaussian sigmas. The last
            entry also sets the window size ``int(4 * sigma + 1)`` and the padding
            ``int(2 * sigma)``.
        data_range (float): dynamic range of the inputs. It scales C1/C2 and
            normalizes the L1 term.
        K (tuple of float): (K1, K2) stabilizing constants.
            NOTE(review): the usual SSIM choice is K2 = 0.03; 0.4 looks deliberate
            here — confirm.
        compensation (float): final multiplier applied to the mixed loss.
        cuda_dev (int): CUDA device for the filter bank when CUDA is available.
        channels (int): number of input channels. The filters are applied per
            channel via grouped convolution.
        use_tensorboard (bool): if True (the default), create a ``SummaryWriter``.
            Each forward then logs the mean MS-SSIM and L1 components.
    """

    def __init__(self, alpha,
                 gaussian_sigmas=(0.5, 1.0, 2.0, 4.0, 8.0),
                 data_range=1.0,
                 K=(0.01, 0.4),
                 compensation=1.0,
                 cuda_dev=0,
                 channels=3,
                 use_tensorboard=True):
        super().__init__()
        self.DR = data_range
        self.C1 = (K[0] * data_range) ** 2
        self.C2 = (K[1] * data_range) ** 2
        self.pad = int(2 * gaussian_sigmas[-1])
        self.alpha = alpha
        self.compensation = compensation
        self.channels = channels
        self.num_scales = len(gaussian_sigmas)

        filter_size = int(4 * gaussian_sigmas[-1] + 1)
        # F.conv2d(..., groups=C) feeds input channel c ONLY to output channels
        # [c*M, (c+1)*M), so the bank must be laid out channel-major:
        # c0_s0 .. c0_s{M-1}, c1_s0 .. c1_s{M-1}, ...
        kernels = torch.stack(
            [self._fspecial_gauss_2d(filter_size, sigma) for sigma in gaussian_sigmas]
        )  # [M, k, k]
        g_masks = kernels.repeat(channels, 1, 1).unsqueeze(1)  # [C*M, 1, k, k]
        if torch.cuda.is_available():
            g_masks = g_masks.cuda(cuda_dev)
        # Non-persistent buffer: follows module.to()/.cuda() but stays out of state_dict.
        self.register_buffer('g_masks', g_masks, persistent=False)

        self.writer = None
        if use_tensorboard:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter()

    def _fspecial_gauss_1d(self, size, sigma):
        """Create a normalized 1-D Gaussian kernel.

        Args:
            size (int): number of taps, centred on index ``size // 2``.
            sigma (float): standard deviation of the Gaussian.

        Returns:
            torch.Tensor: 1-D float kernel of length ``size`` that sums to 1.
        """
        coords = torch.arange(size).to(dtype=torch.float)
        coords -= size // 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g /= g.sum()
        return g.reshape(-1)

    def _fspecial_gauss_2d(self, size, sigma):
        """Create a normalized, separable 2-D Gaussian kernel.

        Args:
            size (int): kernel side length.
            sigma (float): standard deviation of the Gaussian.

        Returns:
            torch.Tensor: ``size x size`` kernel (outer product of the 1-D kernel).
        """
        gaussian_vec = self._fspecial_gauss_1d(size, sigma)
        return torch.outer(gaussian_vec, gaussian_vec)

    def forward(self, x, y, batch_idx=None):
        """Compute the mixed MS-SSIM + L1 loss.

        Args:
            x (Tensor): [B, C, H, W] with C == ``channels``. This is required by
                the grouped convolution.
            y (Tensor): same shape as ``x``.
            batch_idx (int, optional): global step used for TensorBoard logging.

        Returns:
            Tensor: scalar loss.
        """
        num_scales = self.num_scales
        # Match the filter bank to the inputs (device and dtype, e.g. under AMP).
        masks = self.g_masks.to(device=x.device, dtype=x.dtype)

        def gauss(t, weight=masks):
            return F.conv2d(t, weight, groups=self.channels, padding=self.pad)

        mux = gauss(x)
        muy = gauss(y)
        mux2 = mux * mux
        muy2 = muy * muy
        muxy = mux * muy
        sigmax2 = gauss(x * x) - mux2
        sigmay2 = gauss(y * y) - muy2
        sigmaxy = gauss(x * y) - muxy

        # l(j), cs(j) in MS-SSIM: one map per (channel, scale) -> [B, C*M, H', W']
        l = (2 * muxy + self.C1) / (mux2 + muy2 + self.C1)
        cs = (2 * sigmaxy + self.C2) / (sigmax2 + sigmay2 + self.C2)

        # Luminance at the last scale of every channel: indices M-1, 2M-1, ...
        lM = l[:, num_scales - 1::num_scales].prod(dim=1)
        PIcs = cs.prod(dim=1)
        loss_ms_ssim = 1 - lM * PIcs  # [B, H', W']

        loss_l1 = F.l1_loss(x, y, reduction='none')  # [B, C, H, W]
        # Weight each channel's L1 map with its last-scale Gaussian, then average channels.
        gaussian_l1 = gauss(loss_l1, masks[num_scales - 1::num_scales]).mean(1)  # [B, H', W']

        loss_mix = self.alpha * loss_ms_ssim + (1 - self.alpha) * gaussian_l1 / self.DR
        loss_mix = self.compensation * loss_mix
        combined_loss = loss_mix.mean()

        if self.writer is not None:
            self.writer.add_scalar('Loss/ms_ssim_loss-iteration', loss_ms_ssim.mean().item(), batch_idx)
            self.writer.add_scalar('Loss/l1_loss-iteration', gaussian_l1.mean().item(), batch_idx)
        return combined_loss