Spaces:
Configuration error
Configuration error
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from .downsampler import Downsampler | |
def add_module(self, module):
    """Register ``module`` on ``self`` under an automatically chosen name.

    The name is ``str(len(self) + 1)``, so the receiving module must
    support ``len()``.
    """
    next_name = str(len(self) + 1)
    self.add_module(next_name, module)


# Expose the helper as an ``add`` method on every ``torch.nn.Module``.
torch.nn.Module.add = add_module
class Concat(nn.Module):
    """Apply several sub-modules to the same input and concatenate the results.

    When the outputs disagree in size along dimension 2 or 3, each output is
    center-cropped to the smallest size found along those two dimensions
    before concatenation.
    """

    def __init__(self, dim, *args):
        """Store the concatenation dimension and register the branches.

        Each module in ``args`` is registered under its positional index
        (``"0"``, ``"1"``, ...).
        """
        super(Concat, self).__init__()
        self.dim = dim
        for position, branch in enumerate(args):
            self.add_module(str(position), branch)

    def forward(self, input):
        """Run every branch on ``input`` and concatenate along ``self.dim``."""
        outputs = [branch(input) for branch in self._modules.values()]

        # Smallest extent along dimensions 2 and 3 over all branch outputs.
        target_dim2 = min(out.shape[2] for out in outputs)
        target_dim3 = min(out.shape[3] for out in outputs)

        sizes_differ = any(
            out.shape[2] != target_dim2 or out.shape[3] != target_dim3 for out in outputs
        )
        if sizes_differ:
            cropped = []
            for out in outputs:
                # Center crop: drop half of the surplus (rounded down) from the start.
                start2 = (out.size(2) - target_dim2) // 2
                start3 = (out.size(3) - target_dim3) // 2
                cropped.append(
                    out[:, :, start2 : start2 + target_dim2, start3 : start3 + target_dim3]
                )
            outputs = cropped

        return torch.cat(outputs, dim=self.dim)

    def __len__(self):
        """Return the number of registered branches."""
        return len(self._modules)
class GenNoise(nn.Module):
    """Generate standard-normal noise shaped like the input.

    The noise has the same size as the input except along dimension 1,
    whose size is replaced by ``dim2``.
    """

    def __init__(self, dim2):
        """Store ``dim2``, the size of dimension 1 of the generated noise."""
        super(GenNoise, self).__init__()
        self.dim2 = dim2

    def forward(self, input):
        """Return a fresh noise tensor; the values of ``input`` are not used.

        Only the size, dtype and device of ``input`` matter. The result is
        newly allocated and does not require gradients.
        """
        noise_shape = list(input.size())
        noise_shape[1] = self.dim2
        # Allocate directly with the input's dtype and device, then fill in
        # place with N(0, 1) samples. This replaces the deprecated
        # ``torch.autograd.Variable`` / ``.data`` round trip.
        return input.new_empty(noise_shape).normal_()
class Swish(nn.Module):
    """
    Swish activation: ``x * sigmoid(x)``.

    https://arxiv.org/abs/1710.05941
    The hype was so huge that I could not help but try it
    """

    def __init__(self):
        super(Swish, self).__init__()
        # Sigmoid sub-module used as the multiplicative gate.
        self.s = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its sigmoid."""
        gate = self.s(x)
        return x * gate
def act(act_fun="LeakyReLU"):
    """
    Build an activation module.

    ``act_fun`` is either a string naming an activation function or a
    module class / zero-argument factory (e.g. ``nn.ReLU``).

    Recognised strings:
        "LeakyReLU" -> ``nn.LeakyReLU(0.2, inplace=True)``
        "Swish"     -> ``Swish()``
        "ELU"       -> ``nn.ELU()``
        "none"      -> an empty ``nn.Sequential()``

    Any non-string value is called with no arguments and its result returned.

    Raises:
        AssertionError: if ``act_fun`` is a string that is not listed above.
    """
    if not isinstance(act_fun, str):
        return act_fun()

    if act_fun == "LeakyReLU":
        return nn.LeakyReLU(0.2, inplace=True)
    if act_fun == "Swish":
        return Swish()
    if act_fun == "ELU":
        return nn.ELU()
    if act_fun == "none":
        return nn.Sequential()

    # Raised explicitly rather than via ``assert False``: assert statements are
    # stripped under ``python -O``, which would make this silently return None.
    raise AssertionError("unknown activation function: %r" % (act_fun,))
class PixelNormLayer(nn.Module):
    """
    Pixelwise feature vector normalization.

    Divides the input by the root of the mean of its squares taken over
    dimension 1, with ``eps`` added under the square root.
    """

    def __init__(self, eps=1e-8):
        """Store ``eps``, the constant added before taking the square root."""
        super(PixelNormLayer, self).__init__()
        self.eps = eps

    def forward(self, x):
        """Return ``x`` normalized over dimension 1 (same shape as ``x``)."""
        # Use the configured epsilon; a hard-coded 1e-8 here would ignore the
        # constructor argument and contradict what ``__repr__`` reports.
        mean_square = torch.mean(x ** 2, dim=1, keepdim=True)
        return x / torch.sqrt(mean_square + self.eps)

    def __repr__(self):
        return self.__class__.__name__ + "(eps = %s)" % (self.eps,)
def pixelnorm(num_features):
    """Create a ``PixelNormLayer`` with its default settings.

    ``num_features`` is accepted but not used.
    """
    layer = PixelNormLayer()
    return layer
def bn(num_features):
    """Create an ``nn.BatchNorm2d`` layer for ``num_features`` features."""
    return nn.BatchNorm2d(num_features=num_features)
def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad="zero", downsample_mode="stride"):
    """Build a 2-D convolution block with optional padding and downsampling layers.

    Args:
        in_f: number of input channels passed to ``nn.Conv2d``.
        out_f: number of output channels passed to ``nn.Conv2d``.
        kernel_size: convolution kernel size; padding is ``int((kernel_size - 1) / 2)``.
        stride: downsampling factor. Applied by the convolution itself when
            ``downsample_mode == "stride"``, otherwise by a separate layer
            placed after a stride-1 convolution.
        bias: whether the convolution has a bias term.
        pad: ``"reflection"`` inserts an ``nn.ReflectionPad2d`` before an
            unpadded convolution; any other value uses the convolution's own
            ``padding`` argument.
        downsample_mode: ``"stride"``, ``"avg"``, ``"max"``, ``"lanczos2"`` or
            ``"lanczos3"``. Only consulted when ``stride != 1``.

    Returns:
        ``nn.Sequential`` holding, in order, whichever of the padder, the
        convolution and the downsampler are present.

    Raises:
        AssertionError: if ``stride != 1`` and ``downsample_mode`` is not one
            of the supported values.
    """
    downsampler = None
    if stride != 1 and downsample_mode != "stride":
        if downsample_mode == "avg":
            downsampler = nn.AvgPool2d(stride, stride)
        elif downsample_mode == "max":
            downsampler = nn.MaxPool2d(stride, stride)
        elif downsample_mode in ["lanczos2", "lanczos3"]:
            downsampler = Downsampler(
                n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True
            )
        else:
            # Raised explicitly rather than via ``assert False``: asserts are
            # stripped under ``python -O``, which would silently build a block
            # that does not downsample at all.
            raise AssertionError("unknown downsample_mode: %r" % (downsample_mode,))
        # The separate downsampling layer does the reduction, so the
        # convolution itself runs at stride 1.
        stride = 1

    padder = None
    to_pad = int((kernel_size - 1) / 2)
    if pad == "reflection":
        padder = nn.ReflectionPad2d(to_pad)
        # Padding is already applied by the explicit layer above.
        to_pad = 0

    convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)

    layers = [layer for layer in (padder, convolver, downsampler) if layer is not None]
    return nn.Sequential(*layers)
class DecorrelatedColorsToRGB(nn.Module):
    """Converts from a decorrelated color space to RGB. See
    https://github.com/eps696/aphantasia/blob/master/aphantasia/image.py. Usually intended
    to be followed by a sigmoid.

    Inputs are 4-D tensors indexed as ``(n, c, h, w)``. The first three
    channels are mixed by a fixed 3x3 matrix; a channel at index 3, if
    present, is passed through unchanged.
    """

    def __init__(self, inv_color_scale=1.6):
        """Build the 3x3 mixing matrix and register it as the buffer ``colcorr_t``.

        ``inv_color_scale`` divides the first column of the base matrix before
        the matrix is normalized by its largest column norm.
        """
        super().__init__()
        color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]])
        color_correlation_svd_sqrt /= torch.tensor([inv_color_scale, 1.0, 1.0])  # saturate, empirical
        max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max()
        color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
        self.register_buffer("colcorr_t", color_correlation_normalized.T)

    def inverse(self, image):
        """Apply the inverse of the mixing matrix to a 3-channel ``image``."""
        colcorr_t_inv = torch.linalg.inv(self.colcorr_t)
        return torch.einsum("nchw,cd->ndhw", image, colcorr_t_inv)

    def forward(self, image):
        """Mix the color channels of ``image``, keeping channel 3 untouched if present.

        With more than three channels the result has four: the three mixed
        color channels followed by the original channel 3. With exactly three
        channels the result has three.
        """
        # Branch on the channel count. The einsum below already requires a 4-D
        # tensor, so testing ``image.dim() == 4`` sent every valid input,
        # including plain 3-channel images, down the alpha path, where
        # ``image[:, 3]`` raised IndexError.
        if image.size(1) > 3:
            image_rgb, alpha = image[:, :3], image[:, 3].unsqueeze(1)
            image_rgb = torch.einsum("nchw,cd->ndhw", image_rgb, self.colcorr_t)
            image = torch.cat([image_rgb, alpha], dim=1)
        else:
            image = torch.einsum("nchw,cd->ndhw", image, self.colcorr_t)
        return image