Spaces:
Configuration error
Configuration error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| # code taken from https://github.com/ykasten/layered-neural-atlases | |
def count_parameters(model):
    """Return the number of scalar elements held by trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted; frozen
    parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def positionalEncoding_vec(in_tensor, b):
    """Apply a sin/cos positional encoding to a 2-D tensor.

    Every entry of ``in_tensor`` (batch, features) is multiplied by every
    entry of the 1-D tensor ``b`` (one entry per frequency), and the sine and
    cosine of each product are taken.

    The result has shape (batch, 2 * features * freqNum). Within a row the
    values are grouped by frequency: for each entry of ``b`` in order, first
    the sines of all features, then the cosines of all features.
    """
    # Outer product of inputs and frequencies: (batch, features, freqNum).
    phases = torch.einsum("ij, k -> ijk", in_tensor, b)
    sines = torch.sin(phases)
    cosines = torch.cos(phases)
    # Stack sines and cosines along the feature axis: (batch, 2*features, freqNum).
    stacked = torch.cat((sines, cosines), dim=1)
    batch_size = stacked.size(0)
    # Move the frequency axis ahead of the feature axis before flattening so
    # each frequency's block of values is contiguous in the output row.
    by_frequency = stacked.permute(0, 2, 1).contiguous()
    return by_frequency.view(batch_size, -1)
class IMLP(nn.Module):
    """Fully connected ReLU network with optional positional encoding and skips.

    The network is a stack of ``num_layers`` linear layers (the last one is
    the output layer). ReLU is applied before every layer except the first.
    Layers whose index appears in ``skip_layers`` additionally receive the
    network input (after positional encoding, if enabled) concatenated to
    their incoming activations along dimension 1.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of output features per sample.
        hidden_dim: Width of every hidden layer.
        use_positional: If true, inputs are passed through
            ``positionalEncoding_vec`` with frequencies ``2**j * pi`` for
            ``j`` in ``range(positional_dim)`` before the first layer.
        positional_dim: Number of encoding frequencies.
        skip_layers: Indices of layers that get the skip concatenation.
            Index 0 is always treated as a plain first layer when sizing
            the layers.
        num_layers: Total number of linear layers, including the output layer.
        verbose: If true, print the trainable parameter count on construction.
        use_tanh: If true, apply ``tanh`` to the output.
        apply_softmax: If true, apply softmax over the last dimension of the
            output (after ``tanh`` when both are enabled).
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dim=256,
        use_positional=True,
        positional_dim=10,
        skip_layers=(4, 6),
        num_layers=8,  # includes the output layer
        verbose=True,
        use_tanh=True,
        apply_softmax=False,
    ):
        super().__init__()
        self.verbose = verbose
        self.use_tanh = use_tanh
        self.apply_softmax = apply_softmax
        if apply_softmax:
            # Explicit feature axis: omitting ``dim`` relies on a deprecated
            # implicit choice and emits a warning on every call. For 2-D
            # activations this is the same axis the implicit rule picks.
            self.softmax = nn.Softmax(dim=-1)

        if use_positional:
            encoding_dimensions = 2 * input_dim * positional_dim
            frequencies = torch.tensor(
                [(2 ** j) * np.pi for j in range(positional_dim)],
                requires_grad=False,
            )
            # Non-persistent buffer: follows ``.to()`` / ``.double()`` with the
            # rest of the module, but is kept out of ``state_dict`` so
            # checkpoints stay compatible with the plain-attribute version.
            self.register_buffer("b", frequencies, persistent=False)
        else:
            encoding_dimensions = input_dim

        # Per-instance copy so no two models share (or can mutate) one list.
        skip_layers = list(skip_layers)

        self.hidden = nn.ModuleList()
        for i in range(num_layers):
            if i == 0:
                input_dims = encoding_dimensions
            elif i in skip_layers:
                # Skip layers see the hidden activations plus the encoded input.
                input_dims = hidden_dim + encoding_dimensions
            else:
                input_dims = hidden_dim

            is_output_layer = i == num_layers - 1
            layer_width = output_dim if is_output_layer else hidden_dim
            self.hidden.append(nn.Linear(input_dims, layer_width, bias=True))

        self.skip_layers = skip_layers
        self.num_layers = num_layers
        self.positional_dim = positional_dim
        self.use_positional = use_positional

        if self.verbose:
            print(f"Model has {count_parameters(self)} params")

    def forward(self, x):
        """Run the network on ``x`` and return the (optionally squashed) output.

        With positional encoding enabled ``x`` must be 2-D (batch, input_dim),
        because the encoding contracts over exactly two input axes.
        """
        if self.use_positional:
            # Safety net for inputs living on a different device than the module.
            if self.b.device != x.device:
                self.b = self.b.to(x.device)
            x = positionalEncoding_vec(x, self.b)

        # Copy fed to the skip layers. It is detached, so gradients do not
        # flow back to the input through the skip connections.
        skip_input = x.detach().clone()

        for i, layer in enumerate(self.hidden):
            if i > 0:
                x = F.relu(x)
            if i in self.skip_layers:
                x = torch.cat((x, skip_input), 1)
            x = layer(x)

        if self.use_tanh:
            x = torch.tanh(x)
        if self.apply_softmax:
            x = self.softmax(x)
        return x