Upload folder using huggingface_hub
Browse files
- densenet/custom_loss.py +61 -0
- densenet/dense_block.py +59 -0
- densenet/densenet.py +128 -0
- densenet/layer.py +23 -0
- densenet/transitions.py +54 -0
densenet/custom_loss.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class CombinedLoss(nn.Module):
    """
    Combined loss function that includes CrossEntropyLoss and Dice Loss.

    The value returned by ``forward`` is the weighted sum
    ``alpha * cross_entropy + (1 - alpha) * (1 - dice)``.

    Args:
        alpha (float): Weight for CrossEntropyLoss. The weight for Dice Loss is (1 - alpha).
        smooth (float): Smoothing factor for Dice Loss to avoid division by zero.
    """

    def __init__(self, alpha=0.25, smooth=1e-8):  # alpha balances the two losses
        super().__init__()
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()
        self.smooth = smooth

    def forward(self, preds, targets):
        """
        Compute the weighted sum of the cross-entropy and Dice losses.

        Args:
            preds (torch.Tensor): Predictions of shape (B, C, H, W).
            targets (torch.Tensor): Ground truth of shape (B, C, H, W); the
                same tensor is given to both the cross-entropy and Dice terms.
        Returns:
            torch.Tensor: Scalar tensor holding the combined loss.
        """
        # NOTE(review): nn.CrossEntropyLoss applies log-softmax to `preds`
        # itself, while the Dice term uses `preds` as-is -- confirm whether the
        # caller passes logits or probabilities.
        loss_ce = self.ce(preds, targets)
        loss_dice = 1 - self.dice(preds, targets)
        return self.alpha * loss_ce + (1 - self.alpha) * loss_dice

    def _dice_scores(self, preds, targets):
        """Return the Dice score of every 2D slice, reducing over the last two dims."""
        intersection = torch.sum(preds * targets, dim=(-2, -1))
        denominator = torch.sum(preds, dim=(-2, -1)) + torch.sum(targets, dim=(-2, -1))
        # `smooth` keeps the ratio defined (and equal to 1) when both masks are empty.
        return (2 * intersection + self.smooth) / (denominator + self.smooth)

    def dice_per_class(self, preds, targets):
        """
        Compute the Dice score of each slide and average over all slides.

        Args:
            preds (torch.Tensor): The predicted mask of shape (B, H, W).
            targets (torch.Tensor): The ground truth mask of shape (B, H, W).
        Returns:
            torch.Tensor: Scalar tensor with the average Dice score of all slides.
        Raises:
            ValueError: If the batch is empty or ``targets`` is not 3D.
        """
        B, H, W = targets.shape
        if B == 0:
            raise ValueError("Cannot compute the Dice score of an empty batch.")
        return self._dice_scores(preds, targets).mean()

    def dice(self, preds, targets):
        """
        Compute the Dice score of each class and average over all classes.

        Args:
            preds (torch.Tensor): The predicted mask of shape (B, C, H, W).
            targets (torch.Tensor): The ground truth mask of shape (B, C, H, W).
        Returns:
            torch.Tensor: Scalar tensor with the average Dice score of all classes.
        Raises:
            ValueError: If the batch or class dimension is empty, or ``targets``
                is not 4D.
        """
        B, C, H, W = targets.shape
        if B == 0 or C == 0:
            raise ValueError("Cannot compute the Dice score of an empty batch or class set.")
        # The mean over the (B, C) grid equals the average over classes of the
        # per-class batch averages, since every class has the same number of slides.
        return self._dice_scores(preds, targets).mean()
|
| 61 |
+
|
densenet/dense_block.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from layer import Layer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class DenseBlock(nn.Module):
    """
    Dense block for DenseNet.

    The block is a stack of ``Layer`` modules. Every layer receives the block
    input concatenated (along the channel axis) with the outputs of all the
    layers before it. The block returns only the newly produced feature maps,
    i.e. ``num_layers * growth_rate`` channels; the block input itself is NOT
    part of the returned tensor.

    This class was developed following the paper:
    "The One Hundred Layers Tiramisu: Fully Convolutional DenseNets for Semantic Segmentation"
    (https://arxiv.org/pdf/1611.09326) and the reference paper of this project.

    Args:
        in_channels (int): Number of input channels.
        num_layers (int): Number of layers in the dense block (at least 1).
        growth_rate (int): Number of channels produced by each layer.
    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, num_layers, growth_rate):
        super(DenseBlock, self).__init__()
        if num_layers < 1:
            raise ValueError("A DenseBlock needs at least one layer.")
        # Layer i sees the block input plus the i feature maps created before it.
        layers = [
            Layer(in_channels + i * growth_rate, growth_rate)
            for i in range(num_layers)
        ]
        # nn.Sequential is used only as a registered container here; forward()
        # iterates over it instead of calling it.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """
        Run the dense block.

        Args:
            x (torch.Tensor): Input whose channel axis is dim 1.
        Returns:
            torch.Tensor: Outputs of all layers concatenated along dim 1.
        """
        new_features = []
        last_index = len(self.block) - 1
        for index, layer in enumerate(self.block):
            feature = layer(x)
            new_features.append(feature)
            # Nothing consumes the stacked input after the final layer, so the
            # concatenation is skipped there.
            if index != last_index:
                x = torch.cat([x, feature], dim=1)  # Concatenate along channel axis

        # Following https://arxiv.org/pdf/1611.09326, only the new features are returned.
        return torch.cat(new_features, dim=1)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class InceptionX(nn.Module):
    """
    InceptionX block with three branches of different kernel sizes.

    This is the first block of the DenseNet model. The input goes through a
    3x3, a 5x5 and a 7x7 convolution in parallel; the three results are
    concatenated along the channel axis and batch-normalised.

    Args:
        in_channels (int): Number of input channels.
        channels_3x3 (int): Output channels of the 3x3 branch.
        channels_5x5 (int): Output channels of the 5x5 branch.
        channels_7x7 (int): Output channels of the 7x7 branch.

    Attributes:
        out_channels: total number of channels returned by ``forward``
            (24 with the default branch widths).
    """

    def __init__(self, in_channels, channels_3x3=16, channels_5x5=4, channels_7x7=4):
        super(InceptionX, self).__init__()
        # padding = (kernel_size - 1) / 2 keeps the spatial size of every branch
        # equal to the input size, so the branch outputs can be concatenated.
        self.branch_3x3 = nn.Conv2d(in_channels, channels_3x3, kernel_size=3, padding=1, bias=False)
        self.branch_5x5 = nn.Conv2d(in_channels, channels_5x5, kernel_size=5, padding=2, bias=False)
        self.branch_7x7 = nn.Conv2d(in_channels, channels_7x7, kernel_size=7, padding=3, bias=False)
        # The normalisation width is derived from the branches so they cannot drift apart.
        self.out_channels = channels_3x3 + channels_5x5 + channels_7x7
        self.bn = nn.BatchNorm2d(self.out_channels)

    def forward(self, x):
        """
        Apply the three branches and normalise their concatenation.

        Args:
            x (torch.Tensor): Input of shape (B, in_channels, H, W).
        Returns:
            torch.Tensor: Tensor of shape (B, out_channels, H, W).
        """
        branches = [self.branch_3x3(x), self.branch_5x5(x), self.branch_7x7(x)]
        return self.bn(torch.cat(branches, dim=1))  # concatenate along channel axis
|
densenet/densenet.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# The denseNet model implementation in PyTorch is based on the paper:
|
| 2 |
+
# Densely Connected Fully Convolutional Network for Short-Axis Cardiac Cine MR Image Segmentation and Heart Diagnosis Using Random Forest
|
| 3 |
+
# https://link.springer.com/chapter/10.1007/978-3-319-75541-0_15#Tab3
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import sys
|
| 7 |
+
from scipy.ndimage import label
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
sys.path.append("./densenet") # Add the parent directory to the path
|
| 11 |
+
from dense_block import DenseBlock, InceptionX
|
| 12 |
+
from transitions import TransitionDown, TransitionUp
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DenseNet(nn.Module):
    """
    DenseNet model for image segmentation.

    Based on the paper
    "The One Hundred Layers Tiramisu: Fully Convolutional DenseNets for Semantic Segmentation"
    and the reference paper of this project.
    The layers are organized as follows:
    - Inception_X
    - Dense Block (3 layers)
    - Transition Down
    - Dense Block (4 layers)
    - Transition Down
    - Dense Block (5 layers)
    - Transition Down
    - Bottleneck
    - Transition Up
    - Dense Block (5 layers)
    - Transition Up
    - Dense Block (4 layers)
    - Transition Up
    - Dense Block (3 layers)
    - 1x1 convolution
    - softmax activation
    """

    def __init__(self):
        """Build all layers; the network takes 1 input channel and outputs 4 channels."""
        super(DenseNet, self).__init__()
        growth_rate = 8

        # A DenseBlock returns only its new feature maps (num_layers * growth_rate
        # channels); forward() concatenates them with the block input, which is
        # why each transition below takes more channels than the block produces.
        self.inception = InceptionX(1)  # output channels = 24
        self.downdense1 = DenseBlock(24, 3, growth_rate=growth_rate)  # new channels = 24
        self.td1 = TransitionDown(48, 48)  # 24 (input) + 24 (new)
        self.downdense2 = DenseBlock(48, 4, growth_rate=growth_rate)  # new channels = 32
        self.td2 = TransitionDown(80, 80)  # 48 + 32
        self.downdense3 = DenseBlock(80, 5, growth_rate=growth_rate)  # new channels = 40
        self.td3 = TransitionDown(120, 120)  # 80 + 40
        self.bottleneck = DenseBlock(120, 8, growth_rate=7)  # Bottleneck new channels = 56
        self.tu1 = TransitionUp(56, 56)
        self.updense1 = DenseBlock(176, 5, growth_rate=growth_rate)  # 120 (skip) + 56 in, 40 out
        self.tu2 = TransitionUp(40, 40)
        self.updense2 = DenseBlock(120, 4, growth_rate=growth_rate)  # 80 (skip) + 40 in, 32 out
        self.tu3 = TransitionUp(32, 32)
        self.updense3 = DenseBlock(80, 3, growth_rate=growth_rate)  # 48 (skip) + 32 in, 24 out
        self.finalconv = nn.Conv2d(24, out_channels=4, kernel_size=1)  # output channels = 4
        # softmax activation over the channel axis
        self.softmax = nn.Softmax(dim=1)  # output channels = 4

    def forward(self, x):
        """
        Run the segmentation network.

        Args:
            x (torch.Tensor): Input of shape (B, 1, H, W). The three 2x
                down-samplings are undone by three 2x up-samplings whose
                results are concatenated with the skip connections, so H and W
                need to be divisible by 8 (the original author notes 128x128).
        Returns:
            torch.Tensor: Softmax output of shape (B, 4, H, W), i.e. per-pixel
            values that sum to 1 over the channel axis (not raw logits).
        """
        # ---- down-sampling path ----
        x = self.inception(x)  # channels = 24
        x1 = self.downdense1(x)
        x11 = torch.cat([x, x1], dim=1)  # channels = 48
        x12 = self.td1(x11)
        x2 = self.downdense2(x12)
        x21 = torch.cat([x12, x2], dim=1)  # channels = 80
        x22 = self.td2(x21)
        x3 = self.downdense3(x22)
        x31 = torch.cat([x22, x3], dim=1)  # channels = 120
        x32 = self.td3(x31)
        # ---- bottleneck ----
        x4 = self.bottleneck(x32)  # channels = 56
        # ---- up-sampling path with skip connections ----
        x42 = self.tu1(x4)
        x43 = torch.cat([x31, x42], dim=1)  # channels = 176
        x44 = self.updense1(x43)
        x45 = self.tu2(x44)
        x46 = torch.cat([x21, x45], dim=1)  # channels = 120
        x47 = self.updense2(x46)
        x48 = self.tu3(x47)
        x49 = torch.cat([x11, x48], dim=1)  # channels = 80
        x5 = self.updense3(x49)  # channels = 24
        x51 = self.finalconv(x5)
        x52 = self.softmax(x51)
        return x52
        # NOTE: I'm aware that this code is a little messy with the name of the variables.
        # However I did it by hand without LLM help, so I figure it would be nice to leave
        # it like this to show that

    def load_model(self, model_path):
        """
        Load the model weights from a file.

        The weights are mapped to CUDA when it is available and to the CPU
        otherwise.

        Args:
            model_path (str): Path to the model weights file.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources (or pass weights_only=True where the installed
        # torch version supports it).
        self.load_state_dict(torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))

    def get_largest_component(self, mask):
        """
        Keep only the largest connected component of every slide of a mask.

        Components are found with 8-connectivity (diagonal neighbours count as
        connected); every non-zero value is treated as foreground. Slides
        without any foreground stay empty in the output.

        Args:
            mask: A 3D mask (B,W,H)
        Returns:
            A 3D mask with only the largest connected component. (B,W,H)
            It has the same dtype as ``mask``; kept pixels are set to 1.
        Raises:
            ValueError: If ``mask`` is not 3D.
        """
        if len(mask.shape) != 3:
            raise ValueError("The input mask tensor must be a 3D mask.")

        output_mask = np.zeros_like(mask)
        structure = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]  # 8-connectivity

        for slide in range(mask.shape[0]):
            labeled, num_features = label(mask[slide], structure=structure)

            if num_features == 0:
                # Nothing to keep on this slide; move on to the next one
                # instead of aborting the processing of the whole batch.
                continue

            # Find the largest component
            counts = np.bincount(labeled.ravel())
            counts[0] = 0  # Ignore background count
            largest_label = counts.argmax()

            # Create mask for the largest component
            output_mask[slide] = (labeled == largest_label)

        return output_mask
|
| 128 |
+
|
densenet/layer.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
class Layer(nn.Module):
    """
    DenseNet layer: Batch Normalization, ELU activation, 3x3 Convolution
    and Dropout, applied in that order.

    The convolution uses padding 1, so the spatial size of the input is kept.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels. This is the growth rate.
        dropout (float): Probability used by the channel-wise dropout.
    """

    def __init__(self, in_channels, out_channels, dropout=0.2):
        super(Layer, self).__init__()
        self.block = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            # Exponential Linear Unit; in-place is safe because it works on the
            # fresh tensor produced by the batch norm, not on the layer input.
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.Dropout2d(p=dropout),
        )

    def forward(self, x):
        """
        Apply the layer.

        Args:
            x (torch.Tensor): Input of shape (B, in_channels, H, W).
        Returns:
            torch.Tensor: Tensor of shape (B, out_channels, H, W).
        """
        return self.block(x)
|
| 23 |
+
|
densenet/transitions.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
class TransitionDown(nn.Module):
    """
    Transition down block for DenseNet.

    This is the downsampling used in the first half of the network.
    The block applies Batch Normalization, ELU, a 1x1 convolution and
    channel-wise dropout, then downsamples by a factor of 2 using MaxPooling.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        dropout (float): Probability used by the channel-wise dropout.
    """

    def __init__(self, in_channels, out_channels, dropout=0.2):
        super(TransitionDown, self).__init__()
        self.block = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            # Exponential Linear Unit; in-place is safe because it works on the
            # fresh tensor produced by the batch norm, not on the block input.
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.Dropout2d(p=dropout),
            nn.MaxPool2d(kernel_size=2, stride=2),  # Downsamples by 2
        )

    def forward(self, x):
        """
        Apply the transition.

        Args:
            x (torch.Tensor): Input of shape (B, in_channels, H, W).
        Returns:
            torch.Tensor: Tensor of shape (B, out_channels, H // 2, W // 2).
        """
        return self.block(x)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TransitionUp(nn.Module):
    """
    Transition up block for DenseNet.

    Used in the second (decoder) half of the network: a single transposed
    convolution that doubles the height and width of its input.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The output size of a transposed convolution is
        #   Hout = (Hin - 1) * stride - 2 * padding
        #          + dilation * (kernel_size - 1) + output_padding + 1
        # (source: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d).
        # With kernel 3, stride 2 and padding 1 this gives 2 * Hin - 1, which is
        # necessarily an odd number; output_padding=1 (not an elegant fix, but
        # a required one) brings it to exactly 2 * Hin.
        upsample_cfg = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.convtrans = nn.ConvTranspose2d(in_channels, out_channels, **upsample_cfg)  # Upsamples by 2

    def forward(self, x):
        """Return ``x`` upsampled by a factor of 2 along height and width."""
        upsampled = self.convtrans(x)
        return upsampled
|
| 54 |
+
|