# ICCV 2021: Joint Topology-preserving and Feature-refinement Network (JTFN) for curvilinear structure segmentation
import torch
import torch.nn as nn
import torch.nn.functional as F
from .UNet_p import MultiHeadAttention2D_Dual2_2, rand, window_partition, window_unpartition, prompt_partition
class SpatialAttention(nn.Module):
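    """CBAM-style spatial attention: channel-wise average and max maps are
    stacked and reduced by 3x3 and 5x5 convolutions to a single attention map
    squashed to (0, 1) by a sigmoid."""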
def __init__(self):
super(SpatialAttention, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(2, 1, kernel_size=(3, 3), padding=(1, 1)),
nn.Conv2d(1, 1, kernel_size=(5, 5), padding=(2, 2)),
nn.Sigmoid()
)
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv(x)
return x
class ChannelAttention(nn.Module):
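    """CBAM-style channel attention: globally average- and max-pooled vectors
    share a 1x1-conv bottleneck (reduction factor ``reduction``) and their sum
    is passed through a sigmoid to gate each channel."""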
def __init__(self, channel, reduction=2):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc1 = nn.Conv2d(channel, channel // reduction, 1, bias=False)
self.fc2 = nn.Conv2d(channel // reduction, channel, 1, bias=False)
self.activate = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc2(self.fc1(self.avg_pool(x)))
max_out = self.fc2(self.fc1(self.max_pool(x)))
out = avg_out + max_out
out = self.activate(out)
return out
class GAU(nn.Module):
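    """Skip-refinement gate of JTFN.

    The fed-back prediction ``y`` splits the skip feature ``x`` into a
    predicted part ``x * y`` and a residual part ``x * (1 - y)``; spatial and
    channel attention over the residual form a gate ``M`` that blends the raw
    feature with the re-encoded predicted part."""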
def __init__(self, in_channels, use_gau=True, reduce_dim=False, out_channels=None):
super(GAU, self).__init__()
self.use_gau = use_gau
self.reduce_dim = reduce_dim
if self.reduce_dim:
self.down_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
in_channels = out_channels
if self.use_gau:
self.sa = SpatialAttention()
self.ca = ChannelAttention(in_channels)
self.reset_gate = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=2, dilation=2),
            nn.BatchNorm2d(in_channels),  # the conv above keeps in_channels; out_channels may be None when reduce_dim is False
nn.ReLU(inplace=True),
)
def forward(self, x, y):
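        # x: encoder skip feature; y: fed-back prediction map, resized below to
        # x's spatial size and treated as a soft foreground mask.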
if self.reduce_dim:
x = self.down_conv(x)
if self.use_gau:
y = F.interpolate(y, x.shape[-2:], mode='bilinear', align_corners=True)
comx = x * y
resx = x * (1 - y) # bs, c, h, w
x_sa = self.sa(resx) # bs, 1, h, w
x_ca = self.ca(resx) # bs, c, 1, 1
O = self.reset_gate(comx)
M = x_sa * x_ca
RF = M * x + (1 - M) * O
else:
RF = x
return RF
class FIM(nn.Module):
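    """Feature-interaction stage of the JTFN decoder.

    Runs a segmentation branch and a topology/boundary branch in parallel: the
    upsampled segmentation feature is fused with the refined skip ``rf``,
    projected into the topology branch (``s_to_t``), and the topology output is
    projected back (``t_to_s``) as a residual correction to the segmentation
    feature. With ``use_topo=False`` the module reduces to a plain U-Net
    decoder block."""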
def __init__(self, in_channels, out_channels, f_channels, use_topo=True, up=True, bottom=False):
super(FIM, self).__init__()
self.use_topo = use_topo
self.up = up
self.bottom = bottom
if self.up:
self.up_s = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.up_t = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
else:
self.up_s = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.up_t = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.decoder_s = nn.Sequential(
nn.Conv2d(out_channels + f_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
        # Optional deep-supervision head on the segmentation branch (kept
        # disabled in this version):
        # self.inner_s = nn.Sequential(
        #     nn.Conv2d(out_channels, 1, kernel_size=3, padding=1, bias=False),
        #     nn.Sigmoid()
        # )
if self.bottom:
self.st = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1),
nn.BatchNorm2d(in_channels),
nn.ReLU(inplace=True)
)
if self.use_topo:
self.decoder_t = nn.Sequential(
nn.Conv2d(out_channels + out_channels, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.s_to_t = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.t_to_s = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.res_s = nn.Sequential(
nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
        # Optional deep-supervision head on the topology branch (kept disabled
        # in this version):
        # self.inner_t = nn.Sequential(
        #     nn.Conv2d(out_channels, 1, kernel_size=3, padding=1, bias=False),
        #     nn.Sigmoid()
        # )
def forward(self, x_s, x_t, rf):
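        # x_s / x_t: segmentation- and topology-branch features from the coarser
        # stage; rf: the GAU-refined skip feature at this stage's resolution.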
if self.use_topo:
if self.bottom:
x_t = self.st(x_t)
x_s = self.up_s(x_s)
x_t = self.up_t(x_t)
# padding
diffY = rf.size()[2] - x_s.size()[2]
diffX = rf.size()[3] - x_s.size()[3]
x_s = F.pad(x_s, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x_t = F.pad(x_t, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
rf_s = torch.cat((x_s, rf), dim=1)
s = self.decoder_s(rf_s)
s_t = self.s_to_t(s)
t = torch.cat((x_t, s_t), dim=1)
x_t = self.decoder_t(t)
t_s = self.t_to_s(x_t)
s_res = self.res_s(torch.cat((s, t_s), dim=1))
x_s = s + s_res
# t_cls = self.inner_t(x_t)
# s_cls = self.inner_s(x_s)
else:
x_s = self.up_s(x_s)
# padding
diffY = rf.size()[2] - x_s.size()[2]
diffX = rf.size()[3] - x_s.size()[3]
x_s = F.pad(x_s, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
rf_s = torch.cat((x_s, rf), dim=1)
s = self.decoder_s(rf_s)
x_s = s
x_t = x_s
#t_cls = None
#s_cls = self.inner_s(x_s)
return x_s, x_t
class JTFNDecoder(nn.Module):
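    """JTFN decoder: one GAU per encoder scale refines the skip features with
    the fed-back prediction ``y``, then four FIM stages decode bottom-up,
    returning segmentation and boundary features at four scales for deep
    supervision."""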
def __init__(self, channels, use_topo) -> None:
super().__init__()
self.skip_blocks = []
for i in range(5):
self.skip_blocks.append(GAU(channels[i], use_gau=True, reduce_dim=False, out_channels=channels[i]))
        self.fims = []
        for i in range(4):
            # Only the deepest stage (i == 3) receives the bottleneck feature and
            # applies the extra 1x1 transform on its topology input.
            self.fims.append(FIM(channels[i + 1], channels[i], channels[i],
                                 use_topo=use_topo, up=True, bottom=(i == 3)))
self.skip_blocks = nn.ModuleList(self.skip_blocks)
self.fims = nn.ModuleList(self.fims)
def forward(self, x1, x2, x3, x4, x5, y):
x1 = self.skip_blocks[0](x1, y)
x2 = self.skip_blocks[1](x2, y)
x3 = self.skip_blocks[2](x3, y)
x4 = self.skip_blocks[3](x4, y)
x5 = self.skip_blocks[4](x5, y)
x5_seg, x5_bou = x5, x5
x4_seg, x4_bou = self.fims[3](x5_seg, x5_bou, x4)
x3_seg, x3_bou = self.fims[2](x4_seg, x4_bou, x3)
x2_seg, x2_bou = self.fims[1](x3_seg, x3_bou, x2)
x1_seg, x1_bou = self.fims[0](x2_seg, x2_bou, x1)
return [x1_seg, x2_seg, x3_seg, x4_seg], [x1_bou, x2_bou, x3_bou, x4_bou]
class JTFN(nn.Module):
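    """JTFN: iterative topology- and feature-refinement segmentation (ICCV 2021).

    The decoder is unrolled ``steps`` times; after each step the finest
    segmentation logits are fed back as ``y`` to re-gate the skip features
    (GAU treats ``y`` as a soft mask even though the heads emit raw logits).
    Per-scale seg/boundary heads provide deep supervision; ``outputs['output']``
    is the last map upsampled 2x to the input size, which assumes the first
    encoder feature has stride 2."""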
def __init__(self, encoder, decoder, channels, num_classes, steps) -> None:
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.num_classes = num_classes
self.steps = steps
self.conv_seg1_head = nn.Conv2d(channels[0], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_seg2_head = nn.Conv2d(channels[1], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_seg3_head = nn.Conv2d(channels[2], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_seg4_head = nn.Conv2d(channels[3], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_bou1_head = nn.Conv2d(channels[0], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_bou2_head = nn.Conv2d(channels[1], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_bou3_head = nn.Conv2d(channels[2], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_bou4_head = nn.Conv2d(channels[3], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
def forward(self, x):
B, C, H, W = x.shape
y = torch.zeros([B, self.num_classes, H, W], device=x.device)
x1, x2, x3, x4, x5 = self.encoder(x)
outputs = {}
for i in range(self.steps):
segs, bous = self.decoder(x1, x2, x3, x4, x5, y)
x1_seg, x2_seg, x3_seg, x4_seg = segs
x1_bou, x2_bou, x3_bou, x4_bou = bous
x1_seg = self.conv_seg1_head(x1_seg)
x2_seg = self.conv_seg2_head(x2_seg)
x3_seg = self.conv_seg3_head(x3_seg)
x4_seg = self.conv_seg4_head(x4_seg)
x1_bou = self.conv_bou1_head(x1_bou)
x2_bou = self.conv_bou2_head(x2_bou)
x3_bou = self.conv_bou3_head(x3_bou)
x4_bou = self.conv_bou4_head(x4_bou)
y = x1_seg
outputs['step_{}_seg'.format(i)] = [x1_seg, x2_seg, x3_seg, x4_seg]
outputs['step_{}_bou'.format(i)] = [x1_bou, x2_bou, x3_bou, x4_bou]
y = self.upsample(y)
outputs['output'] = y
return outputs
def encoder_forward(self, x, dataset_idx):
        # Manual pass through a timm EfficientNet-style encoder: the last block of
        # each output stage additionally receives the per-sample dataset index.
x = self.encoder.conv_stem(x)
x = self.encoder.bn1(x)
features = []
if 0 in self.encoder._stage_out_idx:
features.append(x) # add stem out
for i in range(len(self.encoder.blocks)):
for j, l in enumerate(self.encoder.blocks[i]):
if j == len(self.encoder.blocks[i]) - 1 and i + 1 in self.encoder._stage_out_idx:
x = l(x, dataset_idx)
else:
x = l(x)
if i + 1 in self.encoder._stage_out_idx:
features.append(x)
return features
class JTFN_DCP(JTFN):
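    """JTFN with per-dataset learnable prompts.

    Each dataset id owns positional prompts (extra rows appended along the
    window height, dim 2) and channel prompts (extra feature channels, dim 1)
    at every encoder scale. Encoder features are split into local windows, the
    prompts are attached to each window, and the dual attention modules fuse
    prompt and feature information before the standard JTFN decoding."""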
def __init__(self, encoder, decoder, channels, num_classes, steps, dataset_idx,
local_window_sizes, encoder_channels, pos_promot_channels, cha_promot_channels,
embed_ratio, strides, att_fusion, use_conv) -> None:
super().__init__(encoder, decoder, channels, num_classes, steps)
self.dataset_idx = dataset_idx
self.local_window_sizes = local_window_sizes
self.pos_promot_channels = pos_promot_channels
pos_p1 = rand((1, encoder_channels[0], pos_promot_channels[0], local_window_sizes[0]), val=3. / encoder_channels[0] ** 0.5)
pos_p2 = rand((1, encoder_channels[1], pos_promot_channels[1], local_window_sizes[1]), val=3. / encoder_channels[1] ** 0.5)
pos_p3 = rand((1, encoder_channels[2], pos_promot_channels[2], local_window_sizes[2]), val=3. / encoder_channels[2] ** 0.5)
pos_p4 = rand((1, encoder_channels[3], pos_promot_channels[3], local_window_sizes[3]), val=3. / encoder_channels[3] ** 0.5)
pos_p5 = rand((1, encoder_channels[4], pos_promot_channels[4], local_window_sizes[4]), val=3. / encoder_channels[4] ** 0.5)
self.pos_promot1 = nn.ParameterDict({str(k): nn.Parameter(pos_p1.detach().clone(), requires_grad=True) for k in dataset_idx})
self.pos_promot2 = nn.ParameterDict({str(k): nn.Parameter(pos_p2.detach().clone(), requires_grad=True) for k in dataset_idx})
self.pos_promot3 = nn.ParameterDict({str(k): nn.Parameter(pos_p3.detach().clone(), requires_grad=True) for k in dataset_idx})
self.pos_promot4 = nn.ParameterDict({str(k): nn.Parameter(pos_p4.detach().clone(), requires_grad=True) for k in dataset_idx})
self.pos_promot5 = nn.ParameterDict({str(k): nn.Parameter(pos_p5.detach().clone(), requires_grad=True) for k in dataset_idx})
self.cha_promot_channels = cha_promot_channels
cha_p1 = rand((1, cha_promot_channels[0], local_window_sizes[0], local_window_sizes[0]), val=3. / local_window_sizes[0])
cha_p2 = rand((1, cha_promot_channels[1], local_window_sizes[1], local_window_sizes[1]), val=3. / local_window_sizes[1])
cha_p3 = rand((1, cha_promot_channels[2], local_window_sizes[2], local_window_sizes[2]), val=3. / local_window_sizes[2])
cha_p4 = rand((1, cha_promot_channels[3], local_window_sizes[3], local_window_sizes[3]), val=3. / local_window_sizes[3])
cha_p5 = rand((1, cha_promot_channels[4], local_window_sizes[4], local_window_sizes[4]), val=3. / local_window_sizes[4])
self.cha_promot1 = nn.ParameterDict({str(k): nn.Parameter(cha_p1.detach().clone(), requires_grad=True) for k in dataset_idx})
self.cha_promot2 = nn.ParameterDict({str(k): nn.Parameter(cha_p2.detach().clone(), requires_grad=True) for k in dataset_idx})
self.cha_promot3 = nn.ParameterDict({str(k): nn.Parameter(cha_p3.detach().clone(), requires_grad=True) for k in dataset_idx})
self.cha_promot4 = nn.ParameterDict({str(k): nn.Parameter(cha_p4.detach().clone(), requires_grad=True) for k in dataset_idx})
self.cha_promot5 = nn.ParameterDict({str(k): nn.Parameter(cha_p5.detach().clone(), requires_grad=True) for k in dataset_idx})
self.strides = strides
self.att1 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[0], dim_cha=encoder_channels[0] + cha_promot_channels[0], embed_dim=encoder_channels[0], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[0], pos_slide=0, cha_slide=0, use_conv=use_conv)
self.att2 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[1], dim_cha=encoder_channels[1] + cha_promot_channels[1], embed_dim=encoder_channels[1], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[1], pos_slide=0, cha_slide=0, use_conv=use_conv)
self.att3 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[2], dim_cha=encoder_channels[2] + cha_promot_channels[2], embed_dim=encoder_channels[2], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[2], pos_slide=0, cha_slide=0, use_conv=use_conv)
self.att4 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[3], dim_cha=encoder_channels[3] + cha_promot_channels[3], embed_dim=encoder_channels[3], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[3], pos_slide=0, cha_slide=0, use_conv=use_conv)
self.att5 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[4], dim_cha=encoder_channels[4] + cha_promot_channels[4], embed_dim=encoder_channels[4], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[4], pos_slide=0, cha_slide=0, use_conv=use_conv)
def get_cha_prompts(self, dataset_idx, batch_size):
        if len(dataset_idx) != batch_size:
            raise ValueError(
                f'got {len(dataset_idx)} dataset ids {dataset_idx} for batch size '
                f'{batch_size}; known ids: {self.dataset_idx}')
        # Stack one channel prompt per sample along the batch dimension.
        prompts1 = torch.cat([self.cha_promot1[str(i)] for i in dataset_idx], dim=0)
        prompts2 = torch.cat([self.cha_promot2[str(i)] for i in dataset_idx], dim=0)
        prompts3 = torch.cat([self.cha_promot3[str(i)] for i in dataset_idx], dim=0)
        prompts4 = torch.cat([self.cha_promot4[str(i)] for i in dataset_idx], dim=0)
        prompts5 = torch.cat([self.cha_promot5[str(i)] for i in dataset_idx], dim=0)
        return prompts1, prompts2, prompts3, prompts4, prompts5
def get_pos_prompts(self, dataset_idx, batch_size):
        if len(dataset_idx) != batch_size:
            raise ValueError(
                f'got {len(dataset_idx)} dataset ids {dataset_idx} for batch size '
                f'{batch_size}; known ids: {self.dataset_idx}')
        # Stack one positional prompt per sample along the batch dimension.
        prompts1 = torch.cat([self.pos_promot1[str(i)] for i in dataset_idx], dim=0)
        prompts2 = torch.cat([self.pos_promot2[str(i)] for i in dataset_idx], dim=0)
        prompts3 = torch.cat([self.pos_promot3[str(i)] for i in dataset_idx], dim=0)
        prompts4 = torch.cat([self.pos_promot4[str(i)] for i in dataset_idx], dim=0)
        prompts5 = torch.cat([self.pos_promot5[str(i)] for i in dataset_idx], dim=0)
        return prompts1, prompts2, prompts3, prompts4, prompts5
def forward(self, x, dataset_idx, return_features=False):
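        # Look up per-sample prompts by dataset id, window-partition the encoder
        # features, attach the prompts to every window, run dual attention, and
        # then decode exactly as in the base JTFN.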
        if isinstance(dataset_idx, torch.Tensor):
            dataset_idx = dataset_idx.cpu().tolist()
cha_promots1, cha_promots2, cha_promots3, cha_promots4, cha_promots5 = self.get_cha_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))
pos_promots1, pos_promots2, pos_promots3, pos_promots4, pos_promots5 = self.get_pos_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))
B, C, H, W = x.shape
y = torch.zeros([B, self.num_classes, H, W], device=x.device)
x1, x2, x3, x4, x5 = self.encoder(x)
if return_features:
pre_x1, pre_x2, pre_x3, pre_x4, pre_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone()
h1, w1 = x1.size()[2:]
h2, w2 = x2.size()[2:]
h3, w3 = x3.size()[2:]
h4, w4 = x4.size()[2:]
h5, w5 = x5.size()[2:]
x1, (Hp1, Wp1), (h_win1, w_win1) = window_partition(x1, self.local_window_sizes[0])
x2, (Hp2, Wp2), (h_win2, w_win2) = window_partition(x2, self.local_window_sizes[1])
x3, (Hp3, Wp3), (h_win3, w_win3) = window_partition(x3, self.local_window_sizes[2])
x4, (Hp4, Wp4), (h_win4, w_win4) = window_partition(x4, self.local_window_sizes[3])
x5, (Hp5, Wp5), (h_win5, w_win5) = window_partition(x5, self.local_window_sizes[4])
cha_promots1 = prompt_partition(cha_promots1, h_win1, w_win1)
cha_promots2 = prompt_partition(cha_promots2, h_win2, w_win2)
cha_promots3 = prompt_partition(cha_promots3, h_win3, w_win3)
cha_promots4 = prompt_partition(cha_promots4, h_win4, w_win4)
cha_promots5 = prompt_partition(cha_promots5, h_win5, w_win5)
pos_promots1 = prompt_partition(pos_promots1, h_win1, w_win1)
pos_promots2 = prompt_partition(pos_promots2, h_win2, w_win2)
pos_promots3 = prompt_partition(pos_promots3, h_win3, w_win3)
pos_promots4 = prompt_partition(pos_promots4, h_win4, w_win4)
pos_promots5 = prompt_partition(pos_promots5, h_win5, w_win5)
        cha_x1 = torch.cat([x1, cha_promots1], dim=1)
        cha_x2 = torch.cat([x2, cha_promots2], dim=1)
        cha_x3 = torch.cat([x3, cha_promots3], dim=1)
        cha_x4 = torch.cat([x4, cha_promots4], dim=1)
        cha_x5 = torch.cat([x5, cha_promots5], dim=1)
        pos_x1 = torch.cat([pos_promots1, x1], dim=2)
        pos_x2 = torch.cat([pos_promots2, x2], dim=2)
        pos_x3 = torch.cat([pos_promots3, x3], dim=2)
        pos_x4 = torch.cat([pos_promots4, x4], dim=2)
        pos_x5 = torch.cat([pos_promots5, x5], dim=2)
        x1 = self.att1(pos_x1, cha_x1)
        x2 = self.att2(pos_x2, cha_x2)
        x3 = self.att3(pos_x3, cha_x3)
        x4 = self.att4(pos_x4, cha_x4)
        x5 = self.att5(pos_x5, cha_x5)
x1 = window_unpartition(x1, self.local_window_sizes[0], (Hp1, Wp1), (h1, w1))
x2 = window_unpartition(x2, self.local_window_sizes[1], (Hp2, Wp2), (h2, w2))
x3 = window_unpartition(x3, self.local_window_sizes[2], (Hp3, Wp3), (h3, w3))
x4 = window_unpartition(x4, self.local_window_sizes[3], (Hp4, Wp4), (h4, w4))
x5 = window_unpartition(x5, self.local_window_sizes[4], (Hp5, Wp5), (h5, w5))
if return_features:
pro_x1, pro_x2, pro_x3, pro_x4, pro_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone()
return (pre_x1, pre_x2, pre_x3, pre_x4, pre_x5), (pro_x1, pro_x2, pro_x3, pro_x4, pro_x5)
outputs = {}
for i in range(self.steps):
segs, bous = self.decoder(x1, x2, x3, x4, x5, y)
x1_seg, x2_seg, x3_seg, x4_seg = segs
x1_bou, x2_bou, x3_bou, x4_bou = bous
x1_seg = self.conv_seg1_head(x1_seg)
x2_seg = self.conv_seg2_head(x2_seg)
x3_seg = self.conv_seg3_head(x3_seg)
x4_seg = self.conv_seg4_head(x4_seg)
x1_bou = self.conv_bou1_head(x1_bou)
x2_bou = self.conv_bou2_head(x2_bou)
x3_bou = self.conv_bou3_head(x3_bou)
x4_bou = self.conv_bou4_head(x4_bou)
y = x1_seg
outputs['step_{}_seg'.format(i)] = [x1_seg, x2_seg, x3_seg, x4_seg]
outputs['step_{}_bou'.format(i)] = [x1_bou, x2_bou, x3_bou, x4_bou]
y = self.upsample(y)
outputs['output'] = y
return outputs
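

# A minimal smoke test for the base JTFN path (an illustrative sketch, not part
# of the original repo). `_ToyEncoder` is a hypothetical stand-in for the real
# timm backbone: JTFN only needs five feature maps at strides 2/4/8/16/32 with
# the widths listed in `channels`; the widths below are assumptions, not values
# from the paper. JTFN_DCP is not exercised here since it depends on the UNet_p
# attention/window helpers. Because of the relative import at the top, run this
# as a module (e.g. `python -m <package>.<this_module>`), not as a script.
if __name__ == "__main__":
    class _ToyEncoder(nn.Module):
        def __init__(self, channels, in_ch=3):
            super().__init__()
            stages, c_in = [], in_ch
            for c_out in channels:
                stages.append(nn.Sequential(
                    nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(inplace=True)))
                c_in = c_out
            self.stages = nn.ModuleList(stages)

        def forward(self, x):
            feats = []
            for stage in self.stages:
                x = stage(x)
                feats.append(x)  # strides 2, 4, 8, 16, 32
            return feats

    channels = [16, 24, 40, 112, 320]
    model = JTFN(_ToyEncoder(channels), JTFNDecoder(channels, use_topo=True),
                 channels, num_classes=1, steps=2)
    out = model(torch.randn(2, 3, 256, 256))
    print(out['output'].shape)  # expected: torch.Size([2, 1, 256, 256])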