NicolasNoya commited on
Commit
89fbe02
·
verified ·
1 Parent(s): 8481733

Upload folder using huggingface_hub

Browse files
densenet/custom_loss.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class CombinedLoss(nn.Module):
6
+ """
7
+ Combined loss function that includes CrossEntropyLoss and Dice Loss.
8
+ The combined loss is a weighted sum of the two losses.
9
+ Args:
10
+ alpha (float): Weight for CrossEntropyLoss. The weight for Dice Loss is (1 - alpha).
11
+ smooth (float): Smoothing factor for Dice Loss to avoid division by zero.
12
+ """
13
+ def __init__(self, alpha=0.25, smooth=1e-8): # alpha balances the two losses
14
+ super().__init__()
15
+ self.alpha = alpha
16
+ self.ce = nn.CrossEntropyLoss()
17
+ self.smooth = smooth
18
+
19
+ def forward(self, preds, targets):
20
+ loss_ce = self.ce(preds, targets)
21
+ loss_dice = 1-self.dice(preds, targets)
22
+ return self.alpha * loss_ce + (1 - self.alpha) * loss_dice
23
+
24
+ def dice_per_class(self, preds, targets):
25
+ """
26
+ This function computes the Dice score for each slide. And outputs
27
+ the average Dice score for all slides.
28
+ Args:
29
+ preds (torch.Tensor): The predicted mask of shape (B, H, W).
30
+ targets (torch.Tensor): The ground truth mask of shape (B, H, W).
31
+ Returns:
32
+ float: The average Dice score for all slides.
33
+ """
34
+ B, H, W = targets.shape
35
+ total_dice = 0
36
+ for i in range(B):
37
+ intersection = torch.sum(preds[i] * targets[i])
38
+ union = torch.sum(preds[i]) + torch.sum(targets[i])
39
+ dice = (2*intersection + self.smooth)/(union + self.smooth)
40
+ total_dice += dice
41
+ return total_dice/B
42
+
43
+
44
+ def dice(self, preds, targets):
45
+ """
46
+ This function computes the Dice score for each class. And outputs
47
+ the average Dice score for all classes.
48
+ Args:
49
+ preds (torch.Tensor): The predicted mask of shape (B, C, H, W).
50
+ targets (torch.Tensor): The ground truth mask of shape (B, C, H, W).
51
+ Returns:
52
+ float: The average Dice score for all classes.
53
+ """
54
+
55
+ B, C, H, W = targets.shape
56
+ dice_for_each_class = 0
57
+
58
+ for i in range(C):
59
+ dice_for_each_class += self.dice_per_class(preds[:,i,:,:], targets[:,i,:,:])
60
+ return dice_for_each_class/C
61
+
densenet/dense_block.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from layer import Layer
4
+
5
+
6
+
7
+ class DenseBlock(nn.Module):
8
+ """
9
+ Dense block for DenseNet.
10
+ This block consists of multiple layers where each layer's output is concatenated
11
+ to the input of the next layer.
12
+ This class was developed following the paper:
13
+ "The One Hundred Layers Tiramisu: Fully Convolutional DenseNets for Semantic Segmentation"
14
+ and the reference paper of this project.
15
+ Args:
16
+ in_channels (int): Number of input channels.
17
+ num_layers (int): Number of layers in the dense block.
18
+ growth_rate (int): Growth rate for the dense block.
19
+ """
20
+ def __init__(self, in_channels, num_layers, growth_rate):
21
+ super(DenseBlock, self).__init__()
22
+ layers = []
23
+ for i in range(num_layers):
24
+ layers.append(Layer(in_channels + i * growth_rate, growth_rate))
25
+ self.block = nn.Sequential(*layers)
26
+
27
+ def forward(self, x):
28
+ outputs = []
29
+ for layer in self.block:
30
+ output = layer(x)
31
+ outputs.append(output)
32
+ x = torch.cat([x, output], dim=1) # Concatenate along channel axis
33
+
34
+ # implementation from the model found in the paper: https://arxiv.org/pdf/1611.09326
35
+ output = torch.cat(outputs, dim=1) # Concatenate all outputs
36
+ return output
37
+
38
+
39
+ class InceptionX(nn.Module):
40
+ """
41
+ InceptionX block with three branches of different kernel sizes.
42
+ This is the first block of the DenseNet model.
43
+ Args:
44
+ in_channels (int): Number of input channels.
45
+ """
46
+ def __init__(self, in_channels):
47
+ super(InceptionX, self).__init__()
48
+ # Each branch with a different padding to keep the size of the output
49
+ self.branch_3x3 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1, bias=False)
50
+ self.branch_5x5 = nn.Conv2d(in_channels, 4, kernel_size=5, padding=2, bias=False)
51
+ self.branch_7x7 = nn.Conv2d(in_channels, 4, kernel_size=7, padding=3, bias=False)
52
+ self.bn = nn.BatchNorm2d(24)
53
+
54
+ def forward(self, x):
55
+ out_3x3 = self.branch_3x3(x)
56
+ out_5x5 = self.branch_5x5(x)
57
+ out_7x7 = self.branch_7x7(x)
58
+ out = torch.cat([out_3x3, out_5x5, out_7x7], dim=1) # concatenate along channel axis
59
+ return self.bn(out)
densenet/densenet.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The denseNet model implementation in PyTorch is based on the paper:
2
+ # Densely Connected Fully Convolutional Network for Short-Axis Cardiac Cine MR Image Segmentation and Heart Diagnosis Using Random Forest
3
+ # https://link.springer.com/chapter/10.1007/978-3-319-75541-0_15#Tab3
4
+ import torch
5
+ import torch.nn as nn
6
+ import sys
7
+ from scipy.ndimage import label
8
+ import numpy as np
9
+
10
+ sys.path.append("./densenet") # Add the parent directory to the path
11
+ from dense_block import DenseBlock, InceptionX
12
+ from transitions import TransitionDown, TransitionUp
13
+
14
+
15
+ class DenseNet(nn.Module):
16
+ def __init__(self):
17
+ """
18
+ This is the DenseNet model for image segmentation based on the paper:
19
+ "The One Hundred Layers Tiramisu: Fully Convolutional DenseNets for Semantic Segmentation"
20
+ and the reference paper of this project.
21
+ The layers are organized as follows:
22
+ - Inception_X
23
+ - Dense Block (3 layers)
24
+ - Transition Down
25
+ - Dense Block (4 layers)
26
+ - Transition Down
27
+ - Dense Block (5 layers)
28
+ - Transition Down
29
+ - Bottleneck
30
+ - Transition Up
31
+ - Dense Block (5 layers)
32
+ - Transition Up
33
+ - Dense Block (4 layers)
34
+ - Transition Up
35
+ - Dense Block (3 layers)
36
+ - 1x1 convolution
37
+ - softmax activation
38
+ """
39
+ super(DenseNet, self).__init__()
40
+ growth_rate = 8
41
+
42
+ self.inception=InceptionX(1) # output channels = 24
43
+ self.downdense1=DenseBlock(24, 3, growth_rate=growth_rate) # output channels = 24
44
+ self.td1=TransitionDown(48, 48)
45
+ self.downdense2=DenseBlock(48, 4, growth_rate=growth_rate) #output channels = 32
46
+ self.td2=TransitionDown(80, 80)
47
+ self.downdense3=DenseBlock(80, 5, growth_rate=growth_rate) # output channels = 40
48
+ self.td3=TransitionDown(120, 120)
49
+ self.bottleneck=DenseBlock(120, 8, growth_rate=7) # Bottleneck output channels = 56
50
+ self.tu1=TransitionUp(56, 56)
51
+ self.updense1=DenseBlock(176, 5, growth_rate=growth_rate) # output channels = 40
52
+ self.tu2=TransitionUp(40, 40)
53
+ self.updense2=DenseBlock(120, 4, growth_rate=growth_rate) # output channels = 32
54
+ self.tu3=TransitionUp(32, 32)
55
+ self.updense3=DenseBlock(80, 3, growth_rate=growth_rate) # output channels = 24
56
+ self.finalconv=nn.Conv2d(24, out_channels=4, kernel_size=1) # output channels = 4
57
+ # softmax activation
58
+ self.softmax = nn.Softmax(dim=1) # output channels = 4
59
+
60
+ def forward(self, x):
61
+ x = self.inception(x) # size 128x128
62
+ x1 = self.downdense1(x)
63
+ x11 = torch.cat([x, x1], dim=1) # channels = 48
64
+ x12 = self.td1(x11)
65
+ x2 = self.downdense2(x12)
66
+ x21 = torch.cat([x12, x2], dim=1) # channels = 56
67
+ x22 = self.td2(x21)
68
+ x3 = self.downdense3(x22)
69
+ x31 = torch.cat([x22, x3], dim=1) # channels = 120
70
+ x32 = self.td3(x31)
71
+ x4 = self.bottleneck(x32)
72
+ x42 = self.tu1(x4)
73
+ x43 = torch.cat([x31, x42], dim=1)
74
+ x44 = self.updense1(x43)
75
+ x45 = self.tu2(x44)
76
+ x46 = torch.cat([x21, x45], dim=1)
77
+ x47 = self.updense2(x46)
78
+ x48 = self.tu3(x47)
79
+ x49 = torch.cat([x11, x48], dim=1)
80
+ x5 = self.updense3(x49)
81
+ x51 = self.finalconv(x5)
82
+ x52 = self.softmax(x51)
83
+ return x52
84
+ # NOTE: I´m aware that this code is a little messy with the name of the variables.
85
+ # However I did it by hand without LLM help, so I figure it would be nice to leave
86
+ # it like this to show that
87
+
88
+ def load_model(self, model_path):
89
+ """
90
+ Load the model weights from a file.
91
+ Args:
92
+ model_path (str): Path to the model weights file.
93
+ """
94
+ self.load_state_dict(torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
95
+
96
+ def get_largest_component(self, mask):
97
+ """
98
+ This function takes a mask and returns the largest connected component of the mask.
99
+
100
+ Args:
101
+ mask: A 3D mask (B,W,H)
102
+ REturns:
103
+ A 3D mask with only the largest connected component. (B,W,H)
104
+ """
105
+ if len(mask.shape) != 3:
106
+ raise ValueError("The input mask tensor must be a 3D mask.")
107
+
108
+ output_mask = np.zeros_like(mask)
109
+
110
+ for slide in range(mask.shape[0]):
111
+ img = mask[slide]
112
+ structure = [[1,1,1],[1,1,1],[1,1,1]]
113
+ labeled, num_features = label(img, structure=structure)
114
+
115
+ if num_features == 0:
116
+ return mask # No components found
117
+
118
+ # Find the largest component
119
+ counts = np.bincount(labeled.flat)
120
+ counts[0] = 0 # Ignore background count
121
+ largest_label = counts.argmax()
122
+
123
+ # Create mask for the largest component
124
+ largest_component = (labeled == largest_label)
125
+ output_mask[slide] = largest_component
126
+
127
+ return output_mask
128
+
densenet/layer.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class Layer(nn.Module):
4
+ def __init__(self, in_channels, out_channels):
5
+ """
6
+ DenseNet layer with Batch Normalization, ELU activation,
7
+ Convolution, and Dropout.
8
+ Args:
9
+ in_channels (int): Number of input channels.
10
+ out_channels (int): Number of output channels. This is the growth rate.
11
+ """
12
+ super(Layer, self).__init__()
13
+ self.block = nn.Sequential(
14
+ nn.BatchNorm2d(in_channels),
15
+ nn.ELU(inplace=True), # Exponential ReLU
16
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
17
+ nn.Dropout2d(p=0.2)
18
+ )
19
+
20
+ def forward(self, x):
21
+ x = self.block(x)
22
+ return x
23
+
densenet/transitions.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class TransitionDown(nn.Module):
4
+ """
5
+ Transition down block for DenseNet.
6
+ This is the downsampling used in the first half of the network.
7
+ The block downsamples the input by a factor of 2 using MaxPooling.
8
+ Args:
9
+ in_channels (int): Number of input channels.
10
+ out_channels (int): Number of output channels.
11
+ """
12
+ def __init__(self, in_channels, out_channels):
13
+ super(TransitionDown, self).__init__()
14
+ self.block = nn.Sequential(
15
+ nn.BatchNorm2d(in_channels),
16
+ nn.ELU(inplace=True), # Exponential ReLU
17
+ nn.Conv2d(in_channels, out_channels, kernel_size=1),
18
+ nn.Dropout2d(p=0.2),
19
+ nn.MaxPool2d(kernel_size=2, stride=2) # Downsamples by 2
20
+ )
21
+
22
+ def forward(self, x):
23
+ return self.block(x)
24
+
25
+
26
+ class TransitionUp(nn.Module):
27
+ """
28
+ Transition up block for DenseNet.
29
+ This is the upsampling used in the second half of the network.
30
+ The block upsamples the input by a factor of 2 using ConvTranspose.
31
+ Args:
32
+ in_channels (int): Number of input channels.
33
+ out_channels (int): Number of output channels.
34
+ """
35
+ def __init__(self, in_channels, out_channels):
36
+ super(TransitionUp, self).__init__()
37
+ self.convtrans = nn.ConvTranspose2d(
38
+ in_channels,
39
+ out_channels,
40
+ kernel_size=3,
41
+ stride=2,
42
+ padding=1,
43
+ # not extremely happy with this output padding
44
+ # but it has to be there because otherwise the
45
+ # output size will be necesarily a odd number
46
+ # according to the formula
47
+ # Hout​=(Hin​−1)×stride[0]−2×padding[0]+dilation[0]×(kernel_size[0]−1)+output_padding[0]+1
48
+ # source: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d
49
+ output_padding=1,
50
+ ) # Upsamples by 2
51
+
52
+ def forward(self, x):
53
+ return self.convtrans(x)
54
+