# ICCV2021, Joint Topology-preserving and Feature-refinement Network for Curvilinear Structure Segmentation
import torch
import torch.nn as nn
import torch.nn.functional as F

from .UNet_p import (
    MultiHeadAttention2D_Dual2_2,
    rand,
    window_partition,
    window_unpartition,
    prompt_partition,
    OneLayerRes,
)


def _conv_bn_relu(in_channels, out_channels, kernel_size, padding=0):
    """Build a stride-1 ``Conv2d -> BatchNorm2d -> ReLU`` stack."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )


class SpatialAttention(nn.Module):
    """Spatial gate computed from the channel-wise mean and max of the input.

    The two single-channel maps are concatenated, passed through a 3x3 and a
    5x5 convolution, and squashed with a sigmoid, giving one channel with the
    same spatial size as the input.
    """

    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=(3, 3), padding=(1, 1)),
            nn.Conv2d(1, 1, kernel_size=(5, 5), padding=(2, 2)),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return the sigmoid spatial gate for ``x`` (reduction over dim 1)."""
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        return self.conv(torch.cat([avg_out, max_out], dim=1))


class ChannelAttention(nn.Module):
    """Channel gate computed from globally average- and max-pooled features.

    Both pooled descriptors go through the same two bias-free 1x1 convolutions
    (``channel -> channel // reduction -> channel``); the results are summed
    and passed through a sigmoid.
    """

    def __init__(self, channel, reduction=2):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(channel, channel // reduction, 1, bias=False)
        self.fc2 = nn.Conv2d(channel // reduction, channel, 1, bias=False)
        self.activate = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid channel gate for ``x`` (spatial size 1x1)."""
        avg_out = self.fc2(self.fc1(self.avg_pool(x)))
        max_out = self.fc2(self.fc1(self.max_pool(x)))
        return self.activate(avg_out + max_out)


class GAU(nn.Module):
    """Gated attention unit that refines a skip feature with a prediction map.

    Args:
        in_channels: Channels of the incoming feature map.
        use_gau: When False, ``forward`` returns the (optionally reduced)
            feature unchanged and no attention modules are created.
        reduce_dim: When True, a 1x1 conv block first maps the feature to
            ``out_channels``.
        out_channels: Target channel count; only used when ``reduce_dim`` is
            True.

    Raises:
        ValueError: If ``reduce_dim`` is True but ``out_channels`` is None.
    """

    def __init__(self, in_channels, use_gau=True, reduce_dim=False, out_channels=None):
        super(GAU, self).__init__()
        self.use_gau = use_gau
        self.reduce_dim = reduce_dim

        if self.reduce_dim:
            if out_channels is None:
                raise ValueError('out_channels must be given when reduce_dim=True')
            self.down_conv = _conv_bn_relu(in_channels, out_channels, kernel_size=1)
            in_channels = out_channels

        if self.use_gau:
            self.sa = SpatialAttention()
            self.ca = ChannelAttention(in_channels)
            self.reset_gate = nn.Sequential(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=2, dilation=2),
                # The conv above emits ``in_channels`` channels, so the norm
                # must be sized with ``in_channels`` (``out_channels`` may be
                # None or differ when ``reduce_dim`` is False).
                nn.BatchNorm2d(in_channels),
                nn.ReLU(inplace=True),
            )

    def forward(self, x, y):
        """Gate feature ``x`` with map ``y``.

        ``y`` is bilinearly resized to the spatial size of ``x`` and multiplied
        with it, so its channel dimension must broadcast against ``x``.
        """
        if self.reduce_dim:
            x = self.down_conv(x)

        if not self.use_gau:
            return x

        y = F.interpolate(y, x.shape[-2:], mode='bilinear', align_corners=True)

        comx = x * y
        resx = x * (1 - y)  # bs, c, h, w

        x_sa = self.sa(resx)  # bs, 1, h, w
        x_ca = self.ca(resx)  # bs, c, 1, 1

        O = self.reset_gate(comx)
        M = x_sa * x_ca

        # Blend the raw feature and the re-encoded gated feature with mask M.
        return M * x + (1 - M) * O


class FIM(nn.Module):
    """Decoder stage with a segmentation stream and an optional second stream.

    The second stream is named ``t`` here and ``bou`` by the callers in this
    file.

    Args:
        in_channels: Channels of the incoming (coarser) stream features.
        out_channels: Channels produced by this stage for both streams.
        f_channels: Channels of the skip feature ``rf`` concatenated into the
            segmentation stream.
        use_topo: Enable the second stream and the cross-stream exchange.
        up: Use a stride-2 transposed conv (True) or a 3x3 conv (False) to
            bring the incoming features to this stage.
        bottom: Create an extra 1x1 conv block that is applied to ``x_t``
            before upsampling (only used in ``forward`` when ``use_topo``).
    """

    def __init__(self, in_channels, out_channels, f_channels, use_topo=True, up=True, bottom=False):
        super(FIM, self).__init__()
        self.use_topo = use_topo
        self.up = up
        self.bottom = bottom

        if self.up:
            self.up_s = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
            self.up_t = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        else:
            self.up_s = _conv_bn_relu(in_channels, out_channels, kernel_size=3, padding=1)
            self.up_t = _conv_bn_relu(in_channels, out_channels, kernel_size=3, padding=1)

        self.decoder_s = nn.Sequential(
            nn.Conv2d(out_channels + f_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        if self.bottom:
            self.st = _conv_bn_relu(in_channels, in_channels, kernel_size=1)

        if self.use_topo:
            self.decoder_t = _conv_bn_relu(out_channels + out_channels, out_channels, kernel_size=3, padding=1)
            self.s_to_t = _conv_bn_relu(out_channels, out_channels, kernel_size=1)
            self.t_to_s = _conv_bn_relu(out_channels, out_channels, kernel_size=1)
            self.res_s = _conv_bn_relu(out_channels * 2, out_channels, kernel_size=3, padding=1)

    def forward(self, x_s, x_t, rf):
        """Run one decoder stage.

        Args:
            x_s: Segmentation-stream feature from the previous stage.
            x_t: Second-stream feature from the previous stage (ignored when
                ``use_topo`` is False).
            rf: Skip feature for this stage; its spatial size is the target
                the upsampled streams are padded to.

        Returns:
            ``(x_s, x_t)`` for this stage. Without ``use_topo`` both entries
            are the same tensor.
        """
        if self.use_topo and self.bottom:
            x_t = self.st(x_t)

        x_s = self.up_s(x_s)

        # Pad the upsampled feature so it matches the skip feature spatially.
        diffY = rf.size()[2] - x_s.size()[2]
        diffX = rf.size()[3] - x_s.size()[3]
        pad = [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
        x_s = F.pad(x_s, pad)

        s = self.decoder_s(torch.cat((x_s, rf), dim=1))
        if not self.use_topo:
            return s, s

        # The second stream reuses the padding computed from the first one.
        x_t = F.pad(self.up_t(x_t), pad)

        # s -> t exchange, then t -> s exchange added back as a residual.
        s_t = self.s_to_t(s)
        x_t = self.decoder_t(torch.cat((x_t, s_t), dim=1))
        t_s = self.t_to_s(x_t)
        s_res = self.res_s(torch.cat((s, t_s), dim=1))
        x_s = s + s_res

        return x_s, x_t


class JTFNDecoder(nn.Module):
    """Five GAU-refined skip connections followed by four FIM stages.

    Args:
        channels: Per-level channel counts; the first five entries are used.
        use_topo: Forwarded to every :class:`FIM`.
    """

    def __init__(self, channels, use_topo) -> None:
        super().__init__()
        self.skip_blocks = nn.ModuleList(
            [GAU(channels[i], use_gau=True, reduce_dim=False, out_channels=channels[i]) for i in range(5)]
        )

        # Only the deepest stage (index 3) is built with ``bottom=True``.
        index = 3
        self.fims = nn.ModuleList(
            [
                FIM(channels[i + 1], channels[i], channels[i], use_topo=use_topo, up=True, bottom=(i == index))
                for i in range(4)
            ]
        )

    def forward(self, x1, x2, x3, x4, x5, y):
        """Decode five encoder features under guidance map ``y``.

        Returns:
            Two lists ``[x1_*, x2_*, x3_*, x4_*]``: the segmentation-stream
            and second-stream features of the four FIM stages.
        """
        skips = [block(feat, y) for block, feat in zip(self.skip_blocks, (x1, x2, x3, x4, x5))]

        # Both streams start from the deepest refined feature.
        seg = bou = skips[4]
        segs = [None] * 4
        bous = [None] * 4
        for level in (3, 2, 1, 0):
            seg, bou = self.fims[level](seg, bou, skips[level])
            segs[level] = seg
            bous[level] = bou

        return segs, bous


class JTFN(nn.Module):
    """Encoder + decoder with per-level prediction heads, run for ``steps``
    refinement iterations.

    Args:
        encoder: Module returning five feature maps for an input batch.
        decoder: Module called as ``decoder(x1, x2, x3, x4, x5, y)`` and
            returning two 4-element sequences of features.
        channels: Channel counts of the decoder outputs; entries 0..3 size the
            prediction heads.
        num_classes: Output channels of every head.
        steps: Number of decoder passes; each pass is guided by the level-1
            segmentation output of the previous pass.
    """

    def __init__(self, encoder, decoder, channels, num_classes, steps) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.num_classes = num_classes
        self.steps = steps

        self.conv_seg1_head = nn.Conv2d(channels[0], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_seg2_head = nn.Conv2d(channels[1], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_seg3_head = nn.Conv2d(channels[2], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_seg4_head = nn.Conv2d(channels[3], num_classes, kernel_size=3, stride=1, padding=1, bias=False)

        self.conv_bou1_head = nn.Conv2d(channels[0], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_bou2_head = nn.Conv2d(channels[1], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_bou3_head = nn.Conv2d(channels[2], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_bou4_head = nn.Conv2d(channels[3], num_classes, kernel_size=3, stride=1, padding=1, bias=False)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def _run_refinement(self, x1, x2, x3, x4, x5, y):
        """Run the decoder ``self.steps`` times and collect all head outputs.

        Returns:
            Dict with ``'step_{i}_seg'`` / ``'step_{i}_bou'`` (lists of four
            head outputs per step) and ``'output'`` (the last guidance map
            ``y`` upsampled by a factor of 2).
        """
        outputs = {}
        for i in range(self.steps):
            segs, bous = self.decoder(x1, x2, x3, x4, x5, y)
            x1_seg, x2_seg, x3_seg, x4_seg = segs
            x1_bou, x2_bou, x3_bou, x4_bou = bous

            seg_preds = [
                self.conv_seg1_head(x1_seg),
                self.conv_seg2_head(x2_seg),
                self.conv_seg3_head(x3_seg),
                self.conv_seg4_head(x4_seg),
            ]
            bou_preds = [
                self.conv_bou1_head(x1_bou),
                self.conv_bou2_head(x2_bou),
                self.conv_bou3_head(x3_bou),
                self.conv_bou4_head(x4_bou),
            ]

            # The level-1 segmentation output guides the next pass.
            y = seg_preds[0]

            outputs['step_{}_seg'.format(i)] = seg_preds
            outputs['step_{}_bou'.format(i)] = bou_preds

        outputs['output'] = self.upsample(y)
        return outputs

    def forward(self, x):
        """Encode ``x`` and run the refinement loop from an all-zero guidance map."""
        B, _, H, W = x.shape
        y = torch.zeros([B, self.num_classes, H, W], device=x.device)
        x1, x2, x3, x4, x5 = self.encoder(x)
        return self._run_refinement(x1, x2, x3, x4, x5, y)

    def encoder_forward(self, x, dataset_idx):
        """Run the encoder stage by stage, passing ``dataset_idx`` to the last
        layer of every stage whose output is collected.

        Relies on the encoder exposing ``conv_stem``, ``bn1``, ``blocks`` and
        ``_stage_out_idx``.
        """
        # efficient net
        x = self.encoder.conv_stem(x)
        x = self.encoder.bn1(x)
        features = []
        if 0 in self.encoder._stage_out_idx:
            features.append(x)  # add stem out
        for i in range(len(self.encoder.blocks)):
            stage = self.encoder.blocks[i]
            is_output_stage = i + 1 in self.encoder._stage_out_idx
            last = len(stage) - 1
            for j, layer in enumerate(stage):
                if j == last and is_output_stage:
                    x = layer(x, dataset_idx)
                else:
                    x = layer(x)
            if is_output_stage:
                features.append(x)
        return features


class JTFN_DCP(JTFN):
    """JTFN whose five encoder features are refined with per-dataset prompts.

    For every encoder level and every key in ``dataset_idx`` two learnable
    prompts are stored (``pos_promot{n}`` and ``cha_promot{n}``). In
    ``forward`` the features are split into windows, concatenated with the
    prompts and passed through ``att{n}`` before decoding.

    Args:
        encoder, decoder, channels, num_classes, steps: See :class:`JTFN`.
        dataset_idx: Iterable of dataset keys; ``str(key)`` names the prompt.
        local_window_sizes: Window size per encoder level (5 entries).
        encoder_channels: Channel count per encoder level (5 entries).
        pos_promot_channels: Third dimension of each ``pos`` prompt.
        cha_promot_channels: Second dimension of each ``cha`` prompt.
        embed_ratio, strides, att_fusion, use_conv: Forwarded to
            ``MultiHeadAttention2D_Dual2_2``.
    """

    def __init__(self, encoder, decoder, channels, num_classes, steps, dataset_idx, local_window_sizes,
                 encoder_channels, pos_promot_channels, cha_promot_channels, embed_ratio, strides, att_fusion,
                 use_conv) -> None:
        super().__init__(encoder, decoder, channels, num_classes, steps)
        self.dataset_idx = dataset_idx
        self.local_window_sizes = local_window_sizes

        # Attribute names and creation order are kept as pos_promot1..5,
        # cha_promot1..5, att1..5 so existing checkpoints keep loading.
        self.pos_promot_channels = pos_promot_channels
        for level in range(5):
            init = rand(
                (1, encoder_channels[level], pos_promot_channels[level], local_window_sizes[level]),
                val=3. / encoder_channels[level] ** 0.5,
            )
            setattr(self, 'pos_promot{}'.format(level + 1), self._make_prompt_dict(init, dataset_idx))

        self.cha_promot_channels = cha_promot_channels
        for level in range(5):
            init = rand(
                (1, cha_promot_channels[level], local_window_sizes[level], local_window_sizes[level]),
                val=3. / local_window_sizes[level],
            )
            setattr(self, 'cha_promot{}'.format(level + 1), self._make_prompt_dict(init, dataset_idx))

        self.strides = strides

        for level in range(5):
            attention = MultiHeadAttention2D_Dual2_2(
                dim_pos=encoder_channels[level],
                dim_cha=encoder_channels[level] + cha_promot_channels[level],
                embed_dim=encoder_channels[level],
                att_fusion=att_fusion,
                num_heads=8,
                embed_dim_ratio=embed_ratio,
                stride=strides[level],
                pos_slide=0,
                cha_slide=0,
                use_conv=use_conv,
            )
            setattr(self, 'att{}'.format(level + 1), attention)

    @staticmethod
    def _make_prompt_dict(init, dataset_idx):
        """Create one trainable copy of ``init`` per dataset key."""
        return nn.ParameterDict(
            {str(k): nn.Parameter(init.detach().clone(), requires_grad=True) for k in dataset_idx}
        )

    def _gather_prompts(self, kind, dataset_idx, batch_size):
        """Stack the ``kind`` ('pos' or 'cha') prompts of all five levels.

        For every level the prompts selected by ``dataset_idx`` are
        concatenated along dim 0, one entry per sample.

        Raises:
            ValueError: If ``dataset_idx`` does not hold one key per sample.
        """
        if len(dataset_idx) != batch_size:
            raise ValueError(
                'expected one dataset index per sample: got dataset_idx={}, batch_size={} '
                '(known datasets: {})'.format(dataset_idx, batch_size, self.dataset_idx)
            )
        return tuple(
            torch.cat(
                [getattr(self, '{}_promot{}'.format(kind, level))[str(i)] for i in dataset_idx],
                dim=0,
            )
            for level in range(1, 6)
        )

    def get_cha_prompts(self, dataset_idx, batch_size):
        """Return the five per-sample ``cha`` prompt batches."""
        return self._gather_prompts('cha', dataset_idx, batch_size)

    def get_pos_prompts(self, dataset_idx, batch_size):
        """Return the five per-sample ``pos`` prompt batches."""
        return self._gather_prompts('pos', dataset_idx, batch_size)

    def forward(self, x, dataset_idx, return_features=False):
        """Encode, apply prompt attention per level, then decode.

        Args:
            x: Input batch.
            dataset_idx: One dataset key per sample (sequence or tensor).
            return_features: When True, skip decoding and return detached
                copies of the encoder features before and after the prompt
                attention as ``(pre_features, post_features)``.

        Returns:
            The output dict of :meth:`JTFN._run_refinement`, or the feature
            tuples described above.
        """
        if isinstance(dataset_idx, torch.Tensor):
            dataset_idx = dataset_idx.tolist()

        cha_prompts = self.get_cha_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))
        pos_prompts = self.get_pos_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))

        B, _, H, W = x.shape
        y = torch.zeros([B, self.num_classes, H, W], device=x.device)

        x1, x2, x3, x4, x5 = self.encoder(x)
        features = [x1, x2, x3, x4, x5]

        if return_features:
            pre_features = tuple(feat.detach().clone() for feat in features)

        refined = []
        for level, feat in enumerate(features):
            h, w = feat.size()[2:]
            window_size = self.local_window_sizes[level]

            feat, (Hp, Wp), (h_win, w_win) = window_partition(feat, window_size)
            cha_prompt = prompt_partition(cha_prompts[level], h_win, w_win)
            pos_prompt = prompt_partition(pos_prompts[level], h_win, w_win)

            # ``cha`` prompts extend dim 1, ``pos`` prompts extend dim 2.
            cha_x = torch.cat([feat, cha_prompt], dim=1)
            pos_x = torch.cat([pos_prompt, feat], dim=2)

            attention = getattr(self, 'att{}'.format(level + 1))
            feat = attention(pos_x, cha_x)
            refined.append(window_unpartition(feat, window_size, (Hp, Wp), (h, w)))

        if return_features:
            post_features = tuple(feat.detach().clone() for feat in refined)
            return pre_features, post_features

        x1, x2, x3, x4, x5 = refined
        return self._run_refinement(x1, x2, x3, x4, x5, y)