Upload 6 files

- models/classifiers.py +172 -0
- models/efficientnet.onnx +3 -0
- models/image.py +195 -0
- models/links.txt +1 -0
- models/model.pth +3 -0
- models/rawnet.py +360 -0
models/classifiers.py
ADDED
@@ -0,0 +1,172 @@
from functools import partial

import numpy as np
import torch
from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
    tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7_ns
from torch import nn
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.pooling import AdaptiveAvgPool2d

encoder_params = {
    "tf_efficientnet_b3_ns": {
        "features": 1536,
        "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b2_ns": {
        "features": 1408,
        "init_op": partial(tf_efficientnet_b2_ns, pretrained=False, drop_path_rate=0.2)
    },
    "tf_efficientnet_b4_ns": {
        "features": 1792,
        "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.5)
    },
    "tf_efficientnet_b5_ns": {
        "features": 2048,
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b4_ns_03d": {
        "features": 1792,
        "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.3)
    },
    "tf_efficientnet_b5_ns_03d": {
        "features": 2048,
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.3)
    },
    "tf_efficientnet_b5_ns_04d": {
        "features": 2048,
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.4)
    },
    "tf_efficientnet_b6_ns": {
        "features": 2304,
        "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b7_ns": {
        "features": 2560,
        "init_op": partial(tf_efficientnet_b7_ns, pretrained=False, drop_path_rate=0.2)
    },
    "tf_efficientnet_b6_ns_04d": {
        "features": 2304,
        "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.4)
    },
}


def setup_srm_weights(input_channels: int = 3) -> torch.Tensor:
    """Creates the SRM kernels for noise analysis."""
    # note: values taken from Zhou et al., "Learning Rich Features for Image Manipulation Detection", CVPR 2018
    srm_kernel = torch.from_numpy(np.array([
        [  # srm 1/2 horiz
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 1., -2., 1., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
        ], [  # srm 1/4
            [0., 0., 0., 0., 0.],
            [0., -1., 2., -1., 0.],
            [0., 2., -4., 2., 0.],
            [0., -1., 2., -1., 0.],
            [0., 0., 0., 0., 0.],
        ], [  # srm 1/12
            [-1., 2., -2., 2., -1.],
            [2., -6., 8., -6., 2.],
            [-2., 8., -12., 8., -2.],
            [2., -6., 8., -6., 2.],
            [-1., 2., -2., 2., -1.],
        ]
    ])).float()
    srm_kernel[0] /= 2
    srm_kernel[1] /= 4
    srm_kernel[2] /= 12
    return srm_kernel.view(3, 1, 5, 5).repeat(1, input_channels, 1, 1)


def setup_srm_layer(input_channels: int = 3) -> torch.nn.Module:
    """Creates an SRM convolution layer for noise analysis."""
    weights = setup_srm_weights(input_channels)
    conv = torch.nn.Conv2d(input_channels, out_channels=3, kernel_size=5, stride=1, padding=2, bias=False)
    with torch.no_grad():
        conv.weight = torch.nn.Parameter(weights, requires_grad=False)
    return conv


class DeepFakeClassifierSRM(nn.Module):
    def __init__(self, encoder, dropout_rate=0.5) -> None:
        super().__init__()
        self.encoder = encoder_params[encoder]["init_op"]()
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.srm_conv = setup_srm_layer(3)
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(encoder_params[encoder]["features"], 1)

    def forward(self, x):
        noise = self.srm_conv(x)
        x = self.encoder.forward_features(noise)
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x


class GlobalWeightedAvgPool2d(nn.Module):
    """
    Global Weighted Average Pooling from the paper "Global Weighted Average
    Pooling Bridges Pixel-level Localization and Image-level Classification"
    """

    def __init__(self, features: int, flatten=False):
        super().__init__()
        self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
        self.flatten = flatten

    def fscore(self, x):
        m = self.conv(x)
        m = m.sigmoid().exp()
        return m

    def norm(self, x: torch.Tensor):
        return x / x.sum(dim=[2, 3], keepdim=True)

    def forward(self, x):
        input_x = x
        x = self.fscore(x)
        x = self.norm(x)
        x = x * input_x
        x = x.sum(dim=[2, 3], keepdim=not self.flatten)
        return x


class DeepFakeClassifier(nn.Module):
    def __init__(self, encoder, dropout_rate=0.0) -> None:
        super().__init__()
        self.encoder = encoder_params[encoder]["init_op"]()
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(encoder_params[encoder]["features"], 1)

    def forward(self, x):
        x = self.encoder.forward_features(x)
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x


class DeepFakeClassifierGWAP(nn.Module):
    def __init__(self, encoder, dropout_rate=0.5) -> None:
        super().__init__()
        self.encoder = encoder_params[encoder]["init_op"]()
        self.avg_pool = GlobalWeightedAvgPool2d(encoder_params[encoder]["features"])
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(encoder_params[encoder]["features"], 1)

    def forward(self, x):
        x = self.encoder.forward_features(x)
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x
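
For orientation, a minimal usage sketch of the DeepFakeClassifier defined above. The "tf_efficientnet_b7_ns" entry in encoder_params is configured with pretrained=False, so construction needs no weight download; the 380x380 input resolution is an illustrative choice, not something this file fixes:

import torch
from models.classifiers import DeepFakeClassifier

# "tf_efficientnet_b7_ns" maps to pretrained=False in encoder_params above.
model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").eval()

batch = torch.randn(2, 3, 380, 380)   # NCHW; resolution is an illustrative choice
with torch.no_grad():
    logits = model(batch)             # shape (2, 1): one raw logit per image
    probs = torch.sigmoid(logits)     # per-image fake probability
print(probs.shape)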
models/efficientnet.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:39774e1cc878ac2b587fd4dc1c96fba084c9fe5ee3106a43b560f6054a69ba26
size 133
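
This is a Git LFS pointer stub; after `git lfs pull` resolves it to the real model, the file could be run with onnxruntime. A minimal sketch, assuming the resolved object is a standard single-input image-classification export (the input shape of 1x3x380x380 is a guess):

import numpy as np
import onnxruntime as ort  # assumes onnxruntime is installed

sess = ort.InferenceSession("models/efficientnet.onnx",
                            providers=["CPUExecutionProvider"])
inp = sess.get_inputs()[0]            # input name/shape depend on how the model was exported
dummy = np.random.rand(1, 3, 380, 380).astype(np.float32)
outputs = sess.run(None, {inp.name: dummy})
print(outputs[0].shape)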
models/image.py
ADDED
@@ -0,0 +1,195 @@
import re
import os
import wget
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from models.rawnet import SincConv, Residual_block
from models.classifiers import DeepFakeClassifier


class ImageEncoder(nn.Module):
    def __init__(self, args):
        super(ImageEncoder, self).__init__()
        self.device = args.device
        self.args = args
        self.flatten = nn.Flatten()
        self.sigmoid = nn.Sigmoid()
        # self.fc = nn.Linear(in_features=2560, out_features=2)
        self.pretrained_image_encoder = args.pretrained_image_encoder
        self.freeze_image_encoder = args.freeze_image_encoder

        if not self.pretrained_image_encoder:
            self.model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to(self.device)
        else:
            self.pretrained_ckpt = torch.load(
                os.path.join('pretrained', 'final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23'),
                map_location=torch.device(self.args.device))
            self.state_dict = self.pretrained_ckpt.get("state_dict", self.pretrained_ckpt)

            self.model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to(self.device)
            print("Loading pretrained image encoder...")
            # strip the "module." prefix left over from DataParallel training
            self.model.load_state_dict(
                {re.sub(r"^module\.", "", k): v for k, v in self.state_dict.items()}, strict=True)
            print("Loaded pretrained image encoder.")

        if self.freeze_image_encoder:
            for name, param in self.model.named_parameters():
                param.requires_grad = False

        # self.model.fc = nn.Identity()

    def forward(self, x):
        x = self.model(x)
        out = self.sigmoid(x)
        # x = self.flatten(x)
        # out = self.fc(x)
        return out


class RawNet(nn.Module):
    def __init__(self, args):
        super(RawNet, self).__init__()

        self.device = args.device
        self.filts = [20, [20, 20], [20, 128], [128, 128]]

        self.Sinc_conv = SincConv(device=self.device,
                                  out_channels=self.filts[0],
                                  kernel_size=1024,
                                  in_channels=args.in_channels)

        self.first_bn = nn.BatchNorm1d(num_features=self.filts[0])
        self.selu = nn.SELU(inplace=True)
        self.block0 = nn.Sequential(Residual_block(nb_filts=self.filts[1], first=True))
        self.block1 = nn.Sequential(Residual_block(nb_filts=self.filts[1]))
        self.block2 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.filts[2][0] = self.filts[2][1]
        self.block3 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.block4 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.block5 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        self.fc_attention0 = self._make_attention_fc(in_features=self.filts[1][-1],
                                                     l_out_features=self.filts[1][-1])
        self.fc_attention1 = self._make_attention_fc(in_features=self.filts[1][-1],
                                                     l_out_features=self.filts[1][-1])
        self.fc_attention2 = self._make_attention_fc(in_features=self.filts[2][-1],
                                                     l_out_features=self.filts[2][-1])
        self.fc_attention3 = self._make_attention_fc(in_features=self.filts[2][-1],
                                                     l_out_features=self.filts[2][-1])
        self.fc_attention4 = self._make_attention_fc(in_features=self.filts[2][-1],
                                                     l_out_features=self.filts[2][-1])
        self.fc_attention5 = self._make_attention_fc(in_features=self.filts[2][-1],
                                                     l_out_features=self.filts[2][-1])

        self.bn_before_gru = nn.BatchNorm1d(num_features=self.filts[2][-1])
        self.gru = nn.GRU(input_size=self.filts[2][-1],
                          hidden_size=args.gru_node,
                          num_layers=args.nb_gru_layer,
                          batch_first=True)

        self.fc1_gru = nn.Linear(in_features=args.gru_node,
                                 out_features=args.nb_fc_node)
        self.fc2_gru = nn.Linear(in_features=args.nb_fc_node,
                                 out_features=args.nb_classes, bias=True)

        self.sig = nn.Sigmoid()
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.pretrained_audio_encoder = args.pretrained_audio_encoder
        self.freeze_audio_encoder = args.freeze_audio_encoder

        if self.pretrained_audio_encoder:
            print("Loading pretrained audio encoder")
            ckpt = torch.load(os.path.join('pretrained', 'RawNet.pth'),
                              map_location=torch.device(self.device))
            print("Loaded pretrained audio encoder")
            self.load_state_dict(ckpt, strict=True)

        if self.freeze_audio_encoder:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x, y=None):
        nb_samp = x.shape[0]
        len_seq = x.shape[1]
        x = x.view(nb_samp, 1, len_seq)

        x = self.Sinc_conv(x)
        x = F.max_pool1d(torch.abs(x), 3)
        x = self.first_bn(x)
        x = self.selu(x)

        # six residual blocks, each gated by an FC attention branch
        x0 = self.block0(x)
        y0 = self.avgpool(x0).view(x0.size(0), -1)  # (batch, filter)
        y0 = self.fc_attention0(y0)
        y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1)  # (batch, filter, 1)
        x = x0 * y0 + y0  # (batch, filter, time) x (batch, filter, 1)

        x1 = self.block1(x)
        y1 = self.avgpool(x1).view(x1.size(0), -1)
        y1 = self.fc_attention1(y1)
        y1 = self.sig(y1).view(y1.size(0), y1.size(1), -1)
        x = x1 * y1 + y1

        x2 = self.block2(x)
        y2 = self.avgpool(x2).view(x2.size(0), -1)
        y2 = self.fc_attention2(y2)
        y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1)
        x = x2 * y2 + y2

        x3 = self.block3(x)
        y3 = self.avgpool(x3).view(x3.size(0), -1)
        y3 = self.fc_attention3(y3)
        y3 = self.sig(y3).view(y3.size(0), y3.size(1), -1)
        x = x3 * y3 + y3

        x4 = self.block4(x)
        y4 = self.avgpool(x4).view(x4.size(0), -1)
        y4 = self.fc_attention4(y4)
        y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1)
        x = x4 * y4 + y4

        x5 = self.block5(x)
        y5 = self.avgpool(x5).view(x5.size(0), -1)
        y5 = self.fc_attention5(y5)
        y5 = self.sig(y5).view(y5.size(0), y5.size(1), -1)
        x = x5 * y5 + y5

        x = self.bn_before_gru(x)
        x = self.selu(x)
        x = x.permute(0, 2, 1)  # (batch, filt, time) >> (batch, time, filt)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]
        x = self.fc1_gru(x)
        x = self.fc2_gru(x)
        output = self.logsoftmax(x)

        return output

    def _make_attention_fc(self, in_features, l_out_features):
        l_fc = []
        l_fc.append(nn.Linear(in_features=in_features,
                              out_features=l_out_features))
        return nn.Sequential(*l_fc)

    def _make_layer(self, nb_blocks, nb_filts, first=False):
        layers = []
        for i in range(nb_blocks):
            first = first if i == 0 else False
            layers.append(Residual_block(nb_filts=nb_filts,
                                         first=first))
            if i == 0:
                nb_filts[0] = nb_filts[1]
        return nn.Sequential(*layers)
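
A hypothetical wiring sketch for the two encoders above. The attribute names match what the constructors read from `args`; the values (GRU sizes, class count, waveform length, image resolution) are illustrative guesses rather than settings taken from this repository:

from types import SimpleNamespace
import torch
from models.image import ImageEncoder, RawNet

args = SimpleNamespace(
    device="cpu",
    pretrained_image_encoder=False, freeze_image_encoder=False,   # True requires checkpoints under pretrained/
    pretrained_audio_encoder=False, freeze_audio_encoder=False,
    in_channels=1,
    gru_node=1024, nb_gru_layer=3, nb_fc_node=1024, nb_classes=2,  # illustrative sizes
)

image_model = ImageEncoder(args).eval()
audio_model = RawNet(args).eval()

with torch.no_grad():
    img_score = image_model(torch.randn(1, 3, 380, 380))  # sigmoid score, shape (1, 1)
    audio_score = audio_model(torch.randn(1, 64600))      # log-softmax, shape (1, nb_classes)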
models/links.txt
CHANGED
@@ -0,0 +1 @@
models/model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0315e9ad76374c2e0f91249847d4b1c8ad8c2b20ac334836e8e79657daa4b63a
size 134
models/rawnet.py
ADDED
@@ -0,0 +1,360 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from torch.utils import data
from collections import OrderedDict
from torch.nn.parameter import Parameter


class SincConv(nn.Module):
    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self, device, out_channels, kernel_size, in_channels=1, sample_rate=16000,
                 stride=1, padding=0, dilation=1, bias=False, groups=1):
        super(SincConv, self).__init__()

        if in_channels != 1:
            msg = "SincConv only supports one input channel (here, in_channels = %i)" % in_channels
            raise ValueError(msg)

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Forcing the filters to be odd (i.e., perfectly symmetric)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1

        self.device = device
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        # initialize filterbanks using the Mel scale
        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)  # Hz to mel conversion
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)  # mel to Hz conversion
        self.mel = filbandwidthsf
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2, (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)

    def forward(self, x):
        # build one Hamming-windowed ideal band-pass filter per Mel band
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            hHigh = (2 * fmax / self.sample_rate) * np.sinc(2 * fmax * self.hsupp / self.sample_rate)
            hLow = (2 * fmin / self.sample_rate) * np.sinc(2 * fmin * self.hsupp / self.sample_rate)
            hideal = hHigh - hLow

            self.band_pass[i, :] = Tensor(np.hamming(self.kernel_size)) * Tensor(hideal)

        band_pass_filter = self.band_pass.to(self.device)

        self.filters = band_pass_filter.view(self.out_channels, 1, self.kernel_size)

        return F.conv1d(x, self.filters, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                        bias=None, groups=1)


class Residual_block(nn.Module):
    def __init__(self, nb_filts, first=False):
        super(Residual_block, self).__init__()
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm1d(num_features=nb_filts[0])

        self.lrelu = nn.LeakyReLU(negative_slope=0.3)

        self.conv1 = nn.Conv1d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=3,
                               padding=1,
                               stride=1)

        self.bn2 = nn.BatchNorm1d(num_features=nb_filts[1])
        self.conv2 = nn.Conv1d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               padding=1,
                               kernel_size=3,
                               stride=1)

        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv1d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=0,
                                             kernel_size=1,
                                             stride=1)
        else:
            self.downsample = False
        self.mp = nn.MaxPool1d(3)

    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.lrelu(out)
        else:
            out = x

        out = self.conv1(out)  # convolve the pre-activated input
        out = self.bn2(out)
        out = self.lrelu(out)
        out = self.conv2(out)

        if self.downsample:
            identity = self.conv_downsample(identity)

        out += identity
        out = self.mp(out)
        return out


class RawNet(nn.Module):
    def __init__(self, d_args, device):
        super(RawNet, self).__init__()

        self.device = device

        self.Sinc_conv = SincConv(device=self.device,
                                  out_channels=d_args['filts'][0],
                                  kernel_size=d_args['first_conv'],
                                  in_channels=d_args['in_channels'])

        self.first_bn = nn.BatchNorm1d(num_features=d_args['filts'][0])
        self.selu = nn.SELU(inplace=True)
        self.block0 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][1], first=True))
        self.block1 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][1]))
        self.block2 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
        d_args['filts'][2][0] = d_args['filts'][2][1]
        self.block3 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
        self.block4 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
        self.block5 = nn.Sequential(Residual_block(nb_filts=d_args['filts'][2]))
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        self.fc_attention0 = self._make_attention_fc(in_features=d_args['filts'][1][-1],
                                                     l_out_features=d_args['filts'][1][-1])
        self.fc_attention1 = self._make_attention_fc(in_features=d_args['filts'][1][-1],
                                                     l_out_features=d_args['filts'][1][-1])
        self.fc_attention2 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
                                                     l_out_features=d_args['filts'][2][-1])
        self.fc_attention3 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
                                                     l_out_features=d_args['filts'][2][-1])
        self.fc_attention4 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
                                                     l_out_features=d_args['filts'][2][-1])
        self.fc_attention5 = self._make_attention_fc(in_features=d_args['filts'][2][-1],
                                                     l_out_features=d_args['filts'][2][-1])

        self.bn_before_gru = nn.BatchNorm1d(num_features=d_args['filts'][2][-1])
        self.gru = nn.GRU(input_size=d_args['filts'][2][-1],
                          hidden_size=d_args['gru_node'],
                          num_layers=d_args['nb_gru_layer'],
                          batch_first=True)

        self.fc1_gru = nn.Linear(in_features=d_args['gru_node'],
                                 out_features=d_args['nb_fc_node'])
        self.fc2_gru = nn.Linear(in_features=d_args['nb_fc_node'],
                                 out_features=d_args['nb_classes'], bias=True)

        self.sig = nn.Sigmoid()
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x, y=None):
        nb_samp = x.shape[0]
        len_seq = x.shape[1]
        x = x.view(nb_samp, 1, len_seq)

        x = self.Sinc_conv(x)
        x = F.max_pool1d(torch.abs(x), 3)
        x = self.first_bn(x)
        x = self.selu(x)

        # six residual blocks, each gated by an FC attention branch
        x0 = self.block0(x)
        y0 = self.avgpool(x0).view(x0.size(0), -1)  # (batch, filter)
        y0 = self.fc_attention0(y0)
        y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1)  # (batch, filter, 1)
        x = x0 * y0 + y0  # (batch, filter, time) x (batch, filter, 1)

        x1 = self.block1(x)
        y1 = self.avgpool(x1).view(x1.size(0), -1)
        y1 = self.fc_attention1(y1)
        y1 = self.sig(y1).view(y1.size(0), y1.size(1), -1)
        x = x1 * y1 + y1

        x2 = self.block2(x)
        y2 = self.avgpool(x2).view(x2.size(0), -1)
        y2 = self.fc_attention2(y2)
        y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1)
        x = x2 * y2 + y2

        x3 = self.block3(x)
        y3 = self.avgpool(x3).view(x3.size(0), -1)
        y3 = self.fc_attention3(y3)
        y3 = self.sig(y3).view(y3.size(0), y3.size(1), -1)
        x = x3 * y3 + y3

        x4 = self.block4(x)
        y4 = self.avgpool(x4).view(x4.size(0), -1)
        y4 = self.fc_attention4(y4)
        y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1)
        x = x4 * y4 + y4

        x5 = self.block5(x)
        y5 = self.avgpool(x5).view(x5.size(0), -1)
        y5 = self.fc_attention5(y5)
        y5 = self.sig(y5).view(y5.size(0), y5.size(1), -1)
        x = x5 * y5 + y5

        x = self.bn_before_gru(x)
        x = self.selu(x)
        x = x.permute(0, 2, 1)  # (batch, filt, time) >> (batch, time, filt)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]
        x = self.fc1_gru(x)
        x = self.fc2_gru(x)
        output = self.logsoftmax(x)
        print(f"Spec output shape: {output.shape}")

        return output

    def _make_attention_fc(self, in_features, l_out_features):
        l_fc = []
        l_fc.append(nn.Linear(in_features=in_features,
                              out_features=l_out_features))
        return nn.Sequential(*l_fc)

    def _make_layer(self, nb_blocks, nb_filts, first=False):
        layers = []
        for i in range(nb_blocks):
            first = first if i == 0 else False
            layers.append(Residual_block(nb_filts=nb_filts,
                                         first=first))
            if i == 0:
                nb_filts[0] = nb_filts[1]
        return nn.Sequential(*layers)

    def summary(self, input_size, batch_size=-1, device="cuda", print_fn=None):
        if print_fn is None:
            print_fn = print
        model = self

        def register_hook(module):
            def hook(module, input, output):
                class_name = str(module.__class__).split(".")[-1].split("'")[0]
                module_idx = len(summary)

                m_key = "%s-%i" % (class_name, module_idx + 1)
                summary[m_key] = OrderedDict()
                summary[m_key]["input_shape"] = list(input[0].size())
                summary[m_key]["input_shape"][0] = batch_size
                if isinstance(output, (list, tuple)):
                    summary[m_key]["output_shape"] = [
                        [-1] + list(o.size())[1:] for o in output
                    ]
                else:
                    summary[m_key]["output_shape"] = list(output.size())
                    if len(summary[m_key]["output_shape"]) != 0:
                        summary[m_key]["output_shape"][0] = batch_size

                params = 0
                if hasattr(module, "weight") and hasattr(module.weight, "size"):
                    params += torch.prod(torch.LongTensor(list(module.weight.size())))
                    summary[m_key]["trainable"] = module.weight.requires_grad
                if hasattr(module, "bias") and hasattr(module.bias, "size"):
                    params += torch.prod(torch.LongTensor(list(module.bias.size())))
                summary[m_key]["nb_params"] = params

            if (
                not isinstance(module, nn.Sequential)
                and not isinstance(module, nn.ModuleList)
                and not (module == model)
            ):
                hooks.append(module.register_forward_hook(hook))

        device = device.lower()
        assert device in [
            "cuda",
            "cpu",
        ], "Input device is not valid, please specify 'cuda' or 'cpu'"

        if device == "cuda" and torch.cuda.is_available():
            dtype = torch.cuda.FloatTensor
        else:
            dtype = torch.FloatTensor
        if isinstance(input_size, tuple):
            input_size = [input_size]
        x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
        summary = OrderedDict()
        hooks = []
        model.apply(register_hook)
        model(*x)
        for h in hooks:
            h.remove()

        print_fn("----------------------------------------------------------------")
        line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
        print_fn(line_new)
        print_fn("================================================================")
        total_params = 0
        total_output = 0
        trainable_params = 0
        for layer in summary:
            # input_shape, output_shape, trainable, nb_params
            line_new = "{:>20} {:>25} {:>15}".format(
                layer,
                str(summary[layer]["output_shape"]),
                "{0:,}".format(summary[layer]["nb_params"]),
            )
            total_params += summary[layer]["nb_params"]
            total_output += np.prod(summary[layer]["output_shape"])
            if "trainable" in summary[layer]:
                if summary[layer]["trainable"]:
                    trainable_params += summary[layer]["nb_params"]
            print_fn(line_new)
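
A usage sketch for this dict-configured RawNet. The 'filts' and 'first_conv' values mirror the configuration hard-coded in models/image.py; the GRU/FC sizes and the roughly 4-second waveform length are illustrative choices:

import torch
from models.rawnet import RawNet

d_args = {
    'filts': [20, [20, 20], [20, 128], [128, 128]],  # as hard-coded in models/image.py
    'first_conv': 1024,                              # Sinc filter length, ditto
    'in_channels': 1,
    'gru_node': 1024, 'nb_gru_layer': 3,             # illustrative sizes
    'nb_fc_node': 1024, 'nb_classes': 2,
}

model = RawNet(d_args, device='cpu').eval()
audio = torch.randn(2, 64600)          # batch of raw waveforms (~4 s at 16 kHz)
with torch.no_grad():
    log_probs = model(audio)           # (2, nb_classes) log-softmax scores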