Spaces: Running on T4
| import math | |
| import torch.nn as nn | |
class Upsample(nn.Module):
    """Upsample module built from stacked conv + PixelShuffle stages.

    For ``scale = 2^n`` the module stacks ``n`` stages of
    ``Conv2d(num_feat -> 4 * num_feat)`` followed by ``PixelShuffle(2)``.
    For ``scale = 3`` it uses a single ``Conv2d(num_feat -> 9 * num_feat)``
    followed by ``PixelShuffle(3)``. In both cases the PixelShuffle folds the
    extra channels back into space, so the output again has ``num_feat``
    channels. ``scale = 1`` (``2^0``) produces an empty, identity stack.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a positive power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        super().__init__()
        m = []
        # ``scale & (scale - 1) == 0`` alone is also true for 0, which used to
        # slip through and crash in ``math.log`` with an unhelpful
        # "math domain error"; require a positive value first.
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # Exact integer log2, avoiding float rounding of math.log(scale, 2).
            num_stages = int(scale).bit_length() - 1
            for _ in range(num_stages):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(
                f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
            )
        self.up = nn.Sequential(*m)

    def forward(self, x):
        """Upsample ``x`` spatially by ``scale``; channel count is preserved.

        ``x`` must have ``num_feat`` channels (the input channels of the first
        ``Conv2d``); presumably an ``(N, C, H, W)`` feature map — confirm
        against callers.
        """
        return self.up(x)
class UpsampleOneStep(nn.Module):
    """Single-stage upsampler: exactly one conv followed by one PixelShuffle.

    Unlike ``Upsample``, which stacks a conv + ``PixelShuffle(2)`` stage per
    factor of two, this always uses one ``Conv2d`` projecting ``num_feat``
    channels to ``scale**2 * num_out_ch`` channels, then one
    ``PixelShuffle(scale)``. The output therefore has ``num_out_ch``
    channels. Used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Any positive integer is accepted, since
            ``PixelShuffle`` supports an arbitrary integer upscale factor.
        num_feat (int): Channel number of the input (intermediate) features.
        num_out_ch (int): Channel number of the output.

    Raises:
        ValueError: If ``scale`` is smaller than 1.
    """

    def __init__(self, scale, num_feat, num_out_ch):
        super().__init__()
        # A zero or negative scale is not a valid upsampling factor; without
        # this check it builds a degenerate conv/PixelShuffle that only fails
        # later, inside forward.
        if scale < 1:
            raise ValueError(
                f"scale {scale} is not supported. Scale must be a positive integer."
            )
        # Public attribute kept for external users; this class's own code
        # does not read it.
        self.num_feat = num_feat
        m = []
        m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
        m.append(nn.PixelShuffle(scale))
        self.up = nn.Sequential(*m)

    def forward(self, x):
        """Upsample ``x`` spatially by ``scale`` and map it to ``num_out_ch`` channels.

        ``x`` must have ``num_feat`` channels (the conv's input channels);
        presumably an ``(N, C, H, W)`` feature map — confirm against callers.
        """
        return self.up(x)