qijie.wei committed on
Commit c5f4ee2 · 1 Parent(s): 55cc90a

first commit

.gitignore ADDED
@@ -0,0 +1,6 @@
1
+ .DS_Store
2
+ data*
3
+ checkpoints
4
+ *.pyc
5
+ output_images
6
+ *.pkl
README.md CHANGED
@@ -10,4 +10,5 @@ pinned: false
10
  license: cc-by-nc-sa-4.0
11
  ---
12

13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
+ Demo for our ICASSP 2025 paper [Convolutional Prompting for Broad-Domain Retinal Vessel Segmentation](https://arxiv.org/abs/2412.18089).
14
+ Please refer to [https://github.com/ruc-aimc-lab/dcp](https://github.com/ruc-aimc-lab/dcp) for more information.
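The Space wires the released checkpoint into a Gradio UI (`app.py` below) on top of a small `Inference` wrapper (`inference.py` below). A minimal sketch of driving that wrapper directly, outside the UI; it assumes the `AIMClab-RUC/UNet_DCP_1024` snapshot used by `app.py` and a local fundus photograph such as the `images/CFP.jpg` sample referenced in `inference.py`:

```python
# Hedged sketch, not part of the commit: run the released checkpoint on one image.
import cv2
from huggingface_hub import snapshot_download
from inference import Inference

model_path = snapshot_download(repo_id='AIMClab-RUC/UNet_DCP_1024')  # same repo as app.py
engine = Inference(model_path=model_path)
mask = engine.inference('images/CFP.jpg', 'CFP')   # uint8 map in [0, 255] (sigmoid * 255)
cv2.imwrite('CFP_vessels.png', mask)
```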
app.py ADDED
@@ -0,0 +1,33 @@
1
+ import gradio as gr
2
+ from inference import Inference
3
+ import os
4
+ from huggingface_hub import snapshot_download
5
+
6
+ #MODEL_ID = os.getenv("MODEL_ID", "your_username/your_model_name") # replace with your model ID
7
+
8
+ model_path = snapshot_download(repo_id='AIMClab-RUC/UNet_DCP_1024')
9
+
10
+ TEXT_OPTIONS = ["CFP", "UWF", "FFA", "SLO", "OCTA"]
11
+
12
+ inference_engine = Inference(model_path=model_path)
13
+
14
+ def main(image, text):
15
+ out = inference_engine.inference(image, text)
16
+ return out
17
+
18
+ interface = gr.Interface(
19
+ fn=main,
20
+ inputs=[
21
+ gr.Image(type="numpy"),
22
+ gr.Dropdown(
23
+ choices=TEXT_OPTIONS,
24
+ label="Modality",
25
+ value=TEXT_OPTIONS[0]
26
+ )
27
+ ],
28
+ outputs=gr.Image(type="numpy"),
29
+ title="Broad domain retinal vessel segmentation",
30
+ description=""
31
+ )
32
+
33
+ interface.launch()
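Because the input widget is `gr.Image(type="numpy")`, Gradio passes `main` an HxWx3 uint8 RGB array together with the dropdown string, and renders whatever NumPy array comes back. A hedged sketch of that contract with a synthetic image (values are illustrative, not from the Space):

```python
# Hedged sketch of the data passed through the Gradio interface above.
import numpy as np
from huggingface_hub import snapshot_download
from inference import Inference

engine = Inference(model_path=snapshot_download(repo_id='AIMClab-RUC/UNet_DCP_1024'))
dummy = np.random.randint(0, 256, size=(512, 512, 3), dtype=np.uint8)  # stands in for an upload
mask = engine.inference(dummy, "CFP")   # the same call main() makes
print(mask.shape, mask.dtype)           # HxW uint8 map, at the resolution set in config.json
```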
inference.py ADDED
@@ -0,0 +1,88 @@
1
+ # Example code for running inference on a pre-trained model
2
+ import os
3
+ import json
4
+ import numpy as np
5
+ import cv2
6
+ import torch
7
+ from models import build_model
8
+
9
+
10
+ # os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"
11
+
12
+ device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
13
+
14
+ def sigmoid(arr):
15
+ return 1. / (1 + np.exp(-arr))
16
+
17
+ class Inference(object):
18
+ def __init__(self, model_path):
19
+ self.model_path = model_path
20
+ config_path = os.path.join(model_path, 'config.json')
21
+ with open(config_path) as fin:
22
+ params = json.load(fin)
23
+ self.model_params = params['model_params']
24
+ self.modality_mapping = params['modality_mapping']
25
+ self.model = self.load_model()
26
+
27
+
28
+ def inference(self, image, modality):
29
+ assert modality in self.modality_mapping, "Modality '{}' not supported".format(modality)
30
+
31
+ image = self.load_image(image)
32
+ modality_idx = self.modality_mapping[modality]
33
+ modality_idx = torch.tensor([modality_idx])
34
+ with torch.no_grad():
35
+ output = self.model.predict(x=image, device=device, dataset_idx=modality_idx)
36
+ output = output.data.cpu().numpy()[0][0]
37
+ output = sigmoid(output) * 255
38
+ output = output.astype(np.uint8)
39
+ return output
40
+
41
+ def load_image(self, image):
42
+ # Load the image and preprocess it
43
+ if isinstance(image, str):
44
+ image = cv2.imread(image)[:, :, [2, 1, 0]]
45
+ #image = image
46
+ image = cv2.resize(image, (self.model_params['size_w'], self.model_params['size_h']))
47
+ image = image.astype(np.float32) / 255.0
48
+ image = np.transpose(image, (2, 0, 1))
49
+ image = np.expand_dims(image, axis=0)
50
+ image = torch.tensor(image)
51
+ return image
52
+
53
+ def load_model(self):
54
+ print('Loading model from {}'.format(self.model_path))
55
+ model = build_model(model_name=self.model_params['net'],
56
+ model_params=self.model_params,
57
+ training=False,
58
+ dataset_idx=list(self.modality_mapping.values()),
59
+ pretrained=False)
60
+ #print(model.model.pos_promot3['0'])
61
+
62
+ model.set_device(device)
63
+ # model.requires_grad_false()
64
+ model.load_model(os.path.join(self.model_path, 'model.pkl'))
65
+ model.set_mode('eval')
66
+
67
+ return model
68
+
69
+
70
+ if __name__ == '__main__':
71
+ model_path = 'checkpoints/UNet_DCP_1024'
72
+ image_paths = [
73
+ 'images/FFA.bmp',
74
+ 'images/CFP.jpg',
75
+ 'images/SLO.jpg',
76
+ 'images/UWF.jpg',
77
+ 'images/OCTA.png'
78
+ ]
79
+ modalities = ['FFA', 'CFP', 'SLO', 'UWF', 'OCTA']
80
+
81
+ output_root = 'output_images'
82
+ os.makedirs(output_root, exist_ok=True)
83
+
84
+ inference = Inference(model_path)
85
+
86
+ for image_path, modality in zip(image_paths, modalities):
87
+ output = inference.inference(image_path, modality)
88
+ cv2.imwrite(os.path.join(output_root, '{}.png'.format(modality)), output)
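`Inference` reads only two top-level keys from the checkpoint's `config.json`: `model_params`, which is handed to `build_model` and queried for `net`, `size_w` and `size_h`, and `modality_mapping`, which maps a modality name to the dataset index used to select the per-domain prompts. A hedged sketch of that layout; the values are illustrative placeholders, not the released configuration:

```python
# Illustrative structure only; the concrete values ship inside the AIMClab-RUC/UNet_DCP_1024 snapshot.
example_config = {
    "model_params": {
        "net": "<model name resolved by models.build_model>",
        "size_w": 1024,   # resize target used in Inference.load_image
        "size_h": 1024,
        # ...remaining architecture hyper-parameters consumed by build_model...
    },
    "modality_mapping": {  # modality string -> dataset/prompt index (indices assumed)
        "CFP": 0, "UWF": 1, "FFA": 2, "SLO": 3, "OCTA": 4,
    },
}
```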
models/UNet_p.py ADDED
@@ -0,0 +1,694 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Tuple
5
+
6
+
7
+ def rand(size, val=0.01):
8
+ out = torch.zeros(size)
9
+
10
+ nn.init.uniform_(out, -val, val)
11
+ return out
12
+
13
+ # window partition utilities adapted from MedSAM
14
+ def window_partition(x: torch.Tensor, window_size: int):
15
+ B, C, H, W = x.size()
16
+ pad_h = (window_size - H % window_size) % window_size
17
+ pad_w = (window_size - W % window_size) % window_size
18
+ if pad_h > 0 or pad_w > 0:
19
+ x = F.pad(x, (0, pad_w, 0, pad_h))
20
+ Hp, Wp = H + pad_h, W + pad_w
21
+
22
+ x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)
23
+ windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size, window_size)
24
+ return windows, (Hp, Wp), (Hp // window_size, Wp // window_size)
25
+
26
+ def prompt_partition(prompt: torch.Tensor, h_windows: int, w_windows: int):
27
+ # prompt: B, C, H, W
28
+ B, C, H, W = prompt.size()
29
+ prompt = prompt.view(B, 1, 1, C, H, W)
30
+ prompt = prompt.repeat((1, h_windows, w_windows, 1, 1, 1)).contiguous().view(-1, C, H, W)
31
+ return prompt
32
+
33
+ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]):
34
+ # windows: B * Hp // window_size * Wp // window_size, C, window_size, window_size
35
+ Hp, Wp = pad_hw
36
+ H, W = hw
37
+ B = (windows.shape[0] * window_size * window_size) // (Hp * Wp)
38
+ # 0 1 2 3 4 5
39
+ x = windows.view(B, Hp // window_size, Wp // window_size, -1, window_size, window_size)
40
+ x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(B, -1, Hp, Wp)
41
+
42
+ if Hp > H or Wp > W:
43
+ x = x[:, :, :H, :W].contiguous()
44
+ return x
45
+
46
+
47
+ class GELU(nn.Module):
48
+ def __init__(self):
49
+ super().__init__()
50
+
51
+ def forward(self, x):
52
+ cdf = 0.5 * (1 + torch.erf(x / 2**0.5))
53
+ return x * cdf
54
+
55
+
56
+ class OneLayerRes(nn.Module):
57
+ def __init__(self, in_features, out_features, kernel_size, padding) -> None:
58
+ super().__init__()
59
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding)
60
+ self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)
61
+
62
+ def forward(self, x):
63
+ x = x + self.weight * self.conv(x)
64
+ return x
65
+
66
+
67
+ class MLP(nn.Module):
68
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.2):
69
+ super().__init__()
70
+ out_features = out_features or in_features
71
+ hidden_features = hidden_features or in_features
72
+ self.fc1 = nn.Linear(in_features, hidden_features)
73
+ self.act = act_layer()
74
+ self.fc2 = nn.Linear(hidden_features, out_features)
75
+ self.drop = nn.Dropout(drop)
76
+
77
+ def forward(self, x):
78
+ x = self.fc1(x)
79
+ x = self.act(x)
80
+ x = self.drop(x)
81
+ x = self.fc2(x)
82
+ return x
83
+
84
+
85
+ class MultiHeadSelfAttention(nn.Module):
86
+ def __init__(self, dim, num_heads=8, drop_rate=0.2):
87
+ super().__init__()
88
+ self.num_heads = num_heads
89
+ head_dim = dim // num_heads
90
+ self.norm = nn.LayerNorm(dim)
91
+
92
+ self.scale = head_dim ** -0.5
93
+
94
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
95
+ self.drop = nn.Dropout(drop_rate)
96
+ self.proj = nn.Linear(dim, dim)
97
+
98
+ def forward(self, x, heat=False):
99
+ B, N, C = x.shape
100
+ out = self.norm(x)
101
+ qkv = self.qkv(out).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
102
+ q, k, v = qkv[0], qkv[1], qkv[2]
103
+
104
+ attn = (q @ k.transpose(-2, -1)) * self.scale
105
+ attn = attn.softmax(dim=-1)
106
+ attn = self.drop(attn)
107
+
108
+ out = (attn @ v).transpose(1, 2).reshape(B, N, C)
109
+ out = self.proj(out)
110
+ out = self.drop(out)
111
+ out = x + out
112
+ if heat:
113
+ return out, attn
114
+ return out
115
+
116
+
117
+ class MultiHeadAttention2D_POS(nn.Module):
118
+ def __init__(self, dim_q, dim_k, dim_v, embed_dim, num_heads=8, drop_rate=0.2, embed_dim_ratio=4, stride=1, slide=0):
119
+ super().__init__()
120
+ self.stride = stride
121
+ self.num_heads = num_heads
122
+
123
+ self.slide = slide
124
+
125
+ self.embed_dim_qk = embed_dim // embed_dim_ratio
126
+
127
+ if self.embed_dim_qk % num_heads != 0:
128
+ self.embed_dim_qk = (self.embed_dim_qk // num_heads + 1) * num_heads
129
+
130
+ self.embed_dim_v = embed_dim
131
+ if self.embed_dim_v % num_heads != 0:
132
+ self.embed_dim_v = (self.embed_dim_v // num_heads + 1) * num_heads
133
+
134
+ head_dim = self.embed_dim_qk // num_heads
135
+
136
+ self.scale = head_dim ** -0.5
137
+
138
+ self.conv_q = nn.Conv2d(in_channels=dim_q, out_channels=self.embed_dim_qk, kernel_size=stride, padding=0, stride=stride)
139
+ self.conv_k = nn.Conv2d(in_channels=dim_k, out_channels=self.embed_dim_qk, kernel_size=stride, padding=0, stride=stride)
140
+ self.conv_v = nn.Conv2d(in_channels=dim_v, out_channels=self.embed_dim_v, kernel_size=stride, padding=0, stride=stride)
141
+
142
+ self.drop = nn.Dropout(drop_rate)
143
+ self.proj_out = nn.Conv2d(in_channels=self.embed_dim_v, out_channels=dim_q, kernel_size=3, padding=1)
144
+ if self.stride > 1:
145
+ self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear')
146
+ else:
147
+ self.upsample = nn.Identity()
148
+
149
+ self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True)
150
+
151
+ def forward(self, q, k, v, heat=False):
152
+ B, _, H_q, W_q = q.size()
153
+ _, _, H_kv, W_kv = k.size()
154
+
155
+ H_q = H_q // self.stride
156
+ W_q = W_q // self.stride
157
+ H_kv = H_kv // self.stride
158
+ W_kv = W_kv // self.stride
159
+
160
+ proj_q = self.conv_q(q).reshape(B, self.num_heads, self.embed_dim_qk // self.num_heads, H_q * W_q).permute(0, 1, 3, 2).contiguous()
161
+ proj_k = self.conv_k(k).reshape(B, self.num_heads, self.embed_dim_qk // self.num_heads, H_kv * W_kv).permute(0, 1, 3, 2).contiguous()
162
+ proj_v = self.conv_v(v).reshape(B, self.num_heads, self.embed_dim_v // self.num_heads, H_kv * W_kv).permute(0, 1, 3, 2).contiguous()
163
+
164
+ attn = (proj_q @ proj_k.transpose(-2, -1)).contiguous() * self.scale # B, self.num_heads, H_q * W_q, H_kv * W_kv
165
+ attn = attn.softmax(dim=-1)
166
+ attn = self.drop(attn)
167
+
168
+ out = (attn @ proj_v) # B, self.num_heads, H_q * W_q, self.embed_dim // self.num_heads
169
+ out = out.transpose(2, 3).contiguous().reshape(B, self.embed_dim_v, H_q, W_q)
170
+
171
+ if self.slide > 0:
172
+ out = out[:, :, self.slide // self.stride:]
173
+ q = q[:, :, self.slide:]
174
+
175
+ out = self.proj_out(out)
176
+ out = self.upsample(out)
177
+ out = self.drop(out)
178
+ out = q + out * self.gamma
179
+ return out
180
+
181
+
182
+ class MultiHeadAttention2D_CHA(nn.Module):
183
+ def __init__(self, dim_q, dim_kv, stride, num_heads=8, drop_rate=0.2, slide=0):
184
+ super().__init__()
185
+ self.num_heads = num_heads
186
+ self.stride = stride
187
+ self.slide = slide
188
+ self.dim_q_out = dim_q - slide
189
+
190
+ self.conv_q = nn.Conv2d(in_channels=dim_q, out_channels=dim_q * num_heads, kernel_size=stride, stride=stride, groups=dim_q)
191
+ self.conv_k = nn.Conv2d(in_channels=dim_kv, out_channels=dim_kv * num_heads, kernel_size=stride, stride=stride, groups=dim_kv)
192
+ self.conv_v = nn.Conv2d(in_channels=dim_kv, out_channels=dim_kv * num_heads, kernel_size=stride, stride=stride, groups=dim_kv)
193
+
194
+
195
+ self.drop = nn.Dropout(drop_rate)
196
+ self.proj_out = nn.ConvTranspose2d(in_channels=self.dim_q_out * num_heads, out_channels=self.dim_q_out, kernel_size=stride, stride=stride, groups=self.dim_q_out)
197
+ self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True)
198
+
199
+ def forward(self, q, k, v, heat=False):
200
+ B, C_q, H_q, W_q = q.size()
201
+ _, C_kv, H_kv, W_kv = k.size()
202
+
203
+ proj_q = self.conv_q(q).reshape(B, self.num_heads, C_q, -1) # batch_size * num_heads * dim_q * (H * W)
204
+ proj_k = self.conv_k(k).reshape(B, self.num_heads, C_kv, -1)
205
+ proj_v = self.conv_v(v).reshape(B, self.num_heads, C_kv, -1) # batch_size * num_heads * dim_kv * (H * W)
206
+
207
+ scale = proj_q.size(3) ** -0.5
208
+ attn = (proj_q @ proj_k.transpose(-2, -1)).contiguous() * scale # batch_size, num_heads, dim_q, dim_kv
209
+ attn = attn.softmax(dim=-1)
210
+ attn = self.drop(attn)
211
+
212
+ out = (attn @ proj_v) # batch_size, num_heads, dim_q, (H * W)
213
+ if self.slide > 0:
214
+ out = out[:, :, :-self.slide]
215
+ out = out.reshape(B, self.num_heads * self.dim_q_out, H_q // self.stride, W_q // self.stride)
216
+
217
+ out = self.proj_out(out)
218
+ out = self.drop(out)
219
+ out = q + out * self.gamma
220
+ return out
221
+
222
+
223
+ class MultiHeadAttention2D_Dual2_2(nn.Module):
224
+ def __init__(self, dim_pos, dim_cha, embed_dim, att_fusion, num_heads=8, drop_rate=0.2, embed_dim_ratio=4, stride=1, cha_slide=0, pos_slide=0, use_conv=True):
225
+ super().__init__()
226
+ self.pos_att = MultiHeadAttention2D_POS(dim_q=dim_pos, dim_k=dim_pos, dim_v=dim_pos, embed_dim=embed_dim, num_heads=num_heads, drop_rate=drop_rate, embed_dim_ratio=embed_dim_ratio, stride=stride, slide=pos_slide)
227
+ self.cha_att = MultiHeadAttention2D_CHA(dim_q=dim_cha, dim_kv=dim_cha, num_heads=num_heads, drop_rate=drop_rate, slide=cha_slide, stride=stride)
228
+ self.att_fusion = att_fusion # concat, add
229
+
230
+ if att_fusion == 'concat':
231
+ channel_in = 2 * (dim_pos - cha_slide)
232
+ if att_fusion == 'add':
233
+ channel_in = (dim_pos - cha_slide)
234
+ channel_out = dim_pos - cha_slide
235
+
236
+ self.use_conv = use_conv
237
+ if use_conv:
238
+ self.conv_out = nn.Sequential(nn.Dropout2d(drop_rate, True), nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1))
239
+ else:
240
+ self.conv_out = nn.Identity()
241
+
242
+
243
+ def forward(self, qkv_pos, qkv_cha, heat=False):
244
+ if qkv_cha is None:
245
+ qkv_cha = qkv_pos
246
+ out_pos = self.pos_att(qkv_pos, qkv_pos, qkv_pos, heat)
247
+ out_cha = self.cha_att(qkv_cha, qkv_cha, qkv_cha, heat)
248
+
249
+ C = out_pos.size(1)
250
+ H = out_cha.size(2)
251
+
252
+ if self.att_fusion == 'concat':
253
+ out = torch.cat([out_pos[:, :, -H:], out_cha[:, :C, :]], dim=1)
254
+ if self.att_fusion == 'add':
255
+ out = (out_pos[:, :, -H:] + out_cha[:, :C, :]) / 2
256
+
257
+ out = self.conv_out(out)
258
+ return out
259
+
260
+
261
+
262
+ class ResMLP(MLP):
263
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.2):
264
+ super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, act_layer=act_layer, drop=drop)
265
+ self.norm = nn.LayerNorm(in_features)
266
+
267
+ def forward(self, x):
268
+ out = self.norm(x)
269
+ out = self.fc1(out)
270
+ out = self.act(out)
271
+ out = self.drop(out)
272
+ out = self.fc2(out)
273
+ out = out + x
274
+ return out
275
+
276
+
277
+ class MHSABlock(nn.Module):
278
+ def __init__(self, dim, num_heads=8, drop_rate=0.2) -> None:
279
+ super().__init__()
280
+ self.mhsa = MultiHeadSelfAttention(dim=dim, num_heads=num_heads, drop_rate=drop_rate)
281
+ self.mlp = ResMLP(in_features=dim, hidden_features=dim*4, out_features=dim)
282
+
283
+ def forward(self, x, heat=False):
284
+
285
+ if heat:
286
+ x, attn = self.mhsa(x, heat=True)
287
+ else:
288
+ x = self.mhsa(x)
289
+ x = self.mlp(x)
290
+ if heat:
291
+ return x, attn
292
+ return x
293
+
294
+
295
+ class SelfAttentionBlocks(nn.Module):
296
+ def __init__(self, dim, block_num, num_heads=8, drop_rate=0.2):
297
+ super().__init__()
298
+ self.block_num = block_num
299
+ assert self.block_num >= 1
300
+
301
+ self.blocks = nn.ModuleList([MHSABlock(dim=dim, num_heads=num_heads, drop_rate=drop_rate)
302
+ for i in range(self.block_num)])
303
+
304
+ def forward(self, x, heat=False):
305
+ attns = []
306
+ for blk in self.blocks:
307
+ if heat:
308
+ x, attn = blk(x, heat=True)
309
+ attns.append(attn)
310
+ else:
311
+ x = blk(x)
312
+ if heat:
313
+ return x, attns
314
+ return x
315
+
316
+
317
+ class conv_block(nn.Module):
318
+ def __init__(self,ch_in,ch_out):
319
+ super(conv_block,self).__init__()
320
+ self.conv = nn.Sequential(
321
+ nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
322
+ nn.BatchNorm2d(ch_out),
323
+ nn.ReLU(inplace=True),
324
+ nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
325
+ nn.BatchNorm2d(ch_out),
326
+ nn.ReLU(inplace=True)
327
+ )
328
+
329
+
330
+ def forward(self,x):
331
+ x = self.conv(x)
332
+ return x
333
+
334
+
335
+ class up_conv(nn.Module):
336
+ def __init__(self,ch_in,ch_out):
337
+ super(up_conv,self).__init__()
338
+ self.up = nn.Sequential(
339
+ nn.Upsample(scale_factor=2),
340
+ nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
341
+ nn.BatchNorm2d(ch_out),
342
+ nn.ReLU(inplace=True)
343
+ )
344
+
345
+ def forward(self,x):
346
+ x = self.up(x)
347
+ return x
348
+
349
+
350
+ class Recurrent_block(nn.Module):
351
+ def __init__(self,ch_out,t=2):
352
+ super(Recurrent_block,self).__init__()
353
+ self.t = t
354
+ self.ch_out = ch_out
355
+ self.conv = nn.Sequential(
356
+ nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
357
+ nn.BatchNorm2d(ch_out),
358
+ nn.ReLU(inplace=True)
359
+ )
360
+
361
+ def forward(self,x):
362
+ for i in range(self.t):
363
+
364
+ if i==0:
365
+ x1 = self.conv(x)
366
+
367
+ x1 = self.conv(x+x1)
368
+ return x1
369
+
370
+
371
+ class RRCNN_block(nn.Module):
372
+ def __init__(self,ch_in,ch_out,t=2):
373
+ super(RRCNN_block,self).__init__()
374
+ self.RCNN = nn.Sequential(
375
+ Recurrent_block(ch_out,t=t),
376
+ Recurrent_block(ch_out,t=t)
377
+ )
378
+ self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)
379
+
380
+ def forward(self,x):
381
+ x = self.Conv_1x1(x)
382
+ x1 = self.RCNN(x)
383
+ return x+x1
384
+
385
+
386
+ class single_conv(nn.Module):
387
+ def __init__(self,ch_in,ch_out):
388
+ super(single_conv,self).__init__()
389
+ self.conv = nn.Sequential(
390
+ nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
391
+ nn.BatchNorm2d(ch_out),
392
+ nn.ReLU(inplace=True)
393
+ )
394
+
395
+ def forward(self,x):
396
+ x = self.conv(x)
397
+ return x
398
+
399
+
400
+ class Attention_block(nn.Module):
401
+ def __init__(self,F_g, F_l, F_int):
402
+ super(Attention_block,self).__init__()
403
+ self.W_g = nn.Sequential(
404
+ nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
405
+ nn.BatchNorm2d(F_int)
406
+ )
407
+
408
+ self.W_x = nn.Sequential(
409
+ nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
410
+ nn.BatchNorm2d(F_int)
411
+ )
412
+
413
+ self.psi = nn.Sequential(
414
+ nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
415
+ nn.BatchNorm2d(1),
416
+ nn.Sigmoid()
417
+ )
418
+
419
+ self.relu = nn.ReLU(inplace=True)
420
+
421
+ def forward(self,g,x):
422
+ g1 = self.W_g(g)
423
+ x1 = self.W_x(x)
424
+ psi = self.relu(g1+x1)
425
+ psi = self.psi(psi)
426
+
427
+ return x*psi
428
+
429
+
430
+ class R2AttUNetDecoder(nn.Module):
431
+ def __init__(self, channels, t=2):
432
+ super(R2AttUNetDecoder,self).__init__()
433
+
434
+ self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear')
435
+
436
+ self.Up5 = up_conv(ch_in=channels[4], ch_out=channels[3])
437
+ self.Att5 = Attention_block(F_g=channels[3], F_l=channels[3], F_int=channels[3]//2)
438
+ self.Up_RRCNN5 = RRCNN_block(ch_in=2 * channels[3], ch_out=channels[3], t=t)
439
+
440
+ self.Up4 = up_conv(ch_in=channels[3], ch_out=channels[2])
441
+ self.Att4 = Attention_block(F_g=channels[2], F_l=channels[2], F_int=channels[2]//2)
442
+ self.Up_RRCNN4 = RRCNN_block(ch_in=2 * channels[2], ch_out=channels[2], t=t)
443
+
444
+ self.Up3 = up_conv(ch_in=channels[2], ch_out=channels[1])
445
+ self.Att3 = Attention_block(F_g=channels[1], F_l=channels[1], F_int=channels[1]//2)
446
+ self.Up_RRCNN3 = RRCNN_block(ch_in=2 * channels[1], ch_out=channels[1], t=t)
447
+
448
+ self.Up2 = up_conv(ch_in=channels[1], ch_out=channels[0])
449
+ self.Att2 = Attention_block(F_g=channels[0], F_l=channels[0], F_int=channels[0]//2)
450
+ self.Up_RRCNN2 = RRCNN_block(ch_in=2 * channels[0], ch_out=channels[0], t=t)
451
+
452
+ def forward(self, x1, x2, x3, x4, x5):
453
+
454
+ out = self.Up5(x5)
455
+ x4_att = self.Att5(g=out, x=x4)
456
+ out = torch.cat((x4_att, out),dim=1)
457
+ out = self.Up_RRCNN5(out)
458
+
459
+ out = self.Up4(out)
460
+ x3_att = self.Att4(g=out, x=x3)
461
+ out = torch.cat((x3_att, out),dim=1)
462
+ out = self.Up_RRCNN4(out)
463
+
464
+ out = self.Up3(out)
465
+ x2_att = self.Att3(g=out, x=x2)
466
+ out = torch.cat((x2_att, out),dim=1)
467
+ out = self.Up_RRCNN3(out)
468
+
469
+ out = self.Up2(out)
470
+ x1_att = self.Att2(g=out, x=x1)
471
+ out = torch.cat((x1_att, out),dim=1)
472
+ out = self.Up_RRCNN2(out)
473
+
474
+ out = self.Upsample(out)
475
+
476
+ return out
477
+
478
+
479
+ class ConvBlock(nn.Module):
480
+ def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=0, bias=True):
481
+ super(ConvBlock, self).__init__()
482
+ self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
483
+ self.bn1 = nn.BatchNorm2d(ch_out)
484
+ self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
485
+ self.bn2 = nn.BatchNorm2d(ch_out)
486
+ self.activate = nn.LeakyReLU(negative_slope=0.01)
487
+
488
+ for m in self.modules():
489
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
490
+ nn.init.kaiming_normal_(m.weight)
491
+ if m.bias is not None:
492
+ nn.init.constant_(m.bias, 0)
493
+ elif isinstance(m, nn.BatchNorm2d):
494
+ nn.init.constant_(m.weight, 1)
495
+ nn.init.constant_(m.bias, 0)
496
+
497
+ def forward(self, x):
498
+ out = self.conv1(x)
499
+ out = self.bn1(out)
500
+ out = self.activate(out)
501
+
502
+ out = self.conv2(out)
503
+ out = self.bn2(out)
504
+ out = self.activate(out)
505
+ return out
506
+
507
+
508
+ class UNetDecoder(nn.Module):
509
+ def __init__(self, channels):
510
+ super(UNetDecoder,self).__init__()
511
+
512
+ self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear')
513
+
514
+ self.Up5 = up_conv(ch_in=channels[4], ch_out=channels[3])
515
+ self.conv5 = ConvBlock(ch_in=2 * channels[3], ch_out=channels[3], kernel_size=3, stride=1, padding=1)
516
+
517
+ self.Up4 = up_conv(ch_in=channels[3], ch_out=channels[2])
518
+ self.conv4 = ConvBlock(ch_in=2 * channels[2], ch_out=channels[2], kernel_size=3, stride=1, padding=1)
519
+
520
+ self.Up3 = up_conv(ch_in=channels[2], ch_out=channels[1])
521
+ self.conv3 = ConvBlock(ch_in=2 * channels[1], ch_out=channels[1], kernel_size=3, stride=1, padding=1)
522
+
523
+ self.Up2 = up_conv(ch_in=channels[1], ch_out=channels[0])
524
+ self.conv2 = ConvBlock(ch_in=2 * channels[0], ch_out=channels[0], kernel_size=3, stride=1, padding=1)
525
+
526
+ def forward(self, x1, x2, x3, x4, x5):
527
+
528
+ out = self.Up5(x5)
529
+ out = torch.cat((x4, out),dim=1)
530
+ out = self.conv5(out)
531
+
532
+ out = self.Up4(out)
533
+ out = torch.cat((x3, out),dim=1)
534
+ out = self.conv4(out)
535
+
536
+ out = self.Up3(out)
537
+ out = torch.cat((x2, out),dim=1)
538
+ out = self.conv3(out)
539
+
540
+ out = self.Up2(out)
541
+ out = torch.cat((x1, out),dim=1)
542
+ out = self.conv2(out)
543
+
544
+ out = self.Upsample(out)
545
+
546
+ return out
547
+
548
+
549
+ class U_Net_P(nn.Module):
550
+ def __init__(self, encoder, decoder, output_ch, num_classes):
551
+ super(U_Net_P, self).__init__()
552
+
553
+ self.encoder = encoder
554
+ self.decoder = decoder
555
+
556
+ self.Last_Conv = nn.Conv2d(output_ch, num_classes, kernel_size=3, stride=1, padding=1)
557
+
558
+
559
+ def forward(self, x):
560
+ # encoding path
561
+ x1, x2, x3, x4, x5 = self.encoder(x)
562
+ x = self.decoder(x1, x2, x3, x4, x5)
563
+ x = self.Last_Conv(x)
564
+
565
+ return x
566
+
567
+
568
+ class Prompt_U_Net_P_DCP(nn.Module):
569
+ def __init__(self, encoder, decoder, output_ch, num_classes, dataset_idx, encoder_channels, prompt_init, pos_promot_channels, cha_promot_channels, embed_ratio, strides, local_window_sizes, att_fusion, use_conv):
570
+ super(Prompt_U_Net_P_DCP, self).__init__()
571
+ self.dataset_idx = dataset_idx
572
+ self.local_window_sizes = local_window_sizes
573
+
574
+ self.encoder = encoder
575
+ self.decoder = decoder
576
+
577
+ self.Last_Conv = nn.Conv2d(output_ch, num_classes, kernel_size=3, stride=1, padding=1)
578
+ if prompt_init == 'zero':
579
+ p_init = torch.zeros
580
+ elif prompt_init == 'one':
581
+ p_init = torch.ones
582
+ elif prompt_init == 'rand':
583
+ p_init = rand
584
+
585
+ else:
586
+ raise Exception(prompt_init)
587
+
588
+ self.pos_promot_channels = pos_promot_channels
589
+ pos_p1 = p_init((1, encoder_channels[0], pos_promot_channels[0], local_window_sizes[0]))
590
+ pos_p2 = p_init((1, encoder_channels[1], pos_promot_channels[1], local_window_sizes[1]))
591
+ pos_p3 = p_init((1, encoder_channels[2], pos_promot_channels[2], local_window_sizes[2]))
592
+ pos_p4 = p_init((1, encoder_channels[3], pos_promot_channels[3], local_window_sizes[3]))
593
+ pos_p5 = p_init((1, encoder_channels[4], pos_promot_channels[4], local_window_sizes[4]))
594
+ self.pos_promot1 = nn.ParameterDict({str(k): nn.Parameter(pos_p1.detach().clone(), requires_grad=True) for k in dataset_idx})
595
+ self.pos_promot2 = nn.ParameterDict({str(k): nn.Parameter(pos_p2.detach().clone(), requires_grad=True) for k in dataset_idx})
596
+ self.pos_promot3 = nn.ParameterDict({str(k): nn.Parameter(pos_p3.detach().clone(), requires_grad=True) for k in dataset_idx})
597
+ self.pos_promot4 = nn.ParameterDict({str(k): nn.Parameter(pos_p4.detach().clone(), requires_grad=True) for k in dataset_idx})
598
+ self.pos_promot5 = nn.ParameterDict({str(k): nn.Parameter(pos_p5.detach().clone(), requires_grad=True) for k in dataset_idx})
599
+
600
+ self.cha_promot_channels = cha_promot_channels
601
+ cha_p1 = p_init((1, cha_promot_channels[0], local_window_sizes[0], local_window_sizes[0]))
602
+ cha_p2 = p_init((1, cha_promot_channels[1], local_window_sizes[1], local_window_sizes[1]))
603
+ cha_p3 = p_init((1, cha_promot_channels[2], local_window_sizes[2], local_window_sizes[2]))
604
+ cha_p4 = p_init((1, cha_promot_channels[3], local_window_sizes[3], local_window_sizes[3]))
605
+ cha_p5 = p_init((1, cha_promot_channels[4], local_window_sizes[4], local_window_sizes[4]))
606
+ self.cha_promot1 = nn.ParameterDict({str(k): nn.Parameter(cha_p1.detach().clone(), requires_grad=True) for k in dataset_idx})
607
+ self.cha_promot2 = nn.ParameterDict({str(k): nn.Parameter(cha_p2.detach().clone(), requires_grad=True) for k in dataset_idx})
608
+ self.cha_promot3 = nn.ParameterDict({str(k): nn.Parameter(cha_p3.detach().clone(), requires_grad=True) for k in dataset_idx})
609
+ self.cha_promot4 = nn.ParameterDict({str(k): nn.Parameter(cha_p4.detach().clone(), requires_grad=True) for k in dataset_idx})
610
+ self.cha_promot5 = nn.ParameterDict({str(k): nn.Parameter(cha_p5.detach().clone(), requires_grad=True) for k in dataset_idx})
611
+
612
+ self.strides = strides
613
+
614
+ self.att1 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[0], dim_cha=encoder_channels[0] + cha_promot_channels[0], embed_dim=encoder_channels[0], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[0], pos_slide=0, cha_slide=0, use_conv=use_conv)
615
+ self.att2 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[1], dim_cha=encoder_channels[1] + cha_promot_channels[1], embed_dim=encoder_channels[1], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[1], pos_slide=0, cha_slide=0, use_conv=use_conv)
616
+ self.att3 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[2], dim_cha=encoder_channels[2] + cha_promot_channels[2], embed_dim=encoder_channels[2], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[2], pos_slide=0, cha_slide=0, use_conv=use_conv)
617
+ self.att4 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[3], dim_cha=encoder_channels[3] + cha_promot_channels[3], embed_dim=encoder_channels[3], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[3], pos_slide=0, cha_slide=0, use_conv=use_conv)
618
+ self.att5 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[4], dim_cha=encoder_channels[4] + cha_promot_channels[4], embed_dim=encoder_channels[4], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[4], pos_slide=0, cha_slide=0, use_conv=use_conv)
619
+
620
+ def get_cha_prompts(self, dataset_idx, batch_size):
621
+ if len(dataset_idx) != batch_size:
622
+ raise Exception(dataset_idx, self.dataset_idx, batch_size)
623
+ promots1 = torch.concatenate([self.cha_promot1[str(i)] for i in dataset_idx], dim=0)
624
+ promots2 = torch.concatenate([self.cha_promot2[str(i)] for i in dataset_idx], dim=0)
625
+ promots3 = torch.concatenate([self.cha_promot3[str(i)] for i in dataset_idx], dim=0)
626
+ promots4 = torch.concatenate([self.cha_promot4[str(i)] for i in dataset_idx], dim=0)
627
+ promots5 = torch.concatenate([self.cha_promot5[str(i)] for i in dataset_idx], dim=0)
628
+ return promots1, promots2, promots3, promots4, promots5
629
+
630
+ def get_pos_prompts(self, dataset_idx, batch_size):
631
+ if len(dataset_idx) != batch_size:
632
+ raise Exception(dataset_idx, self.dataset_idx)
633
+ promots1 = torch.concatenate([self.pos_promot1[str(i)] for i in dataset_idx], dim=0)
634
+ promots2 = torch.concatenate([self.pos_promot2[str(i)] for i in dataset_idx], dim=0)
635
+ promots3 = torch.concatenate([self.pos_promot3[str(i)] for i in dataset_idx], dim=0)
636
+ promots4 = torch.concatenate([self.pos_promot4[str(i)] for i in dataset_idx], dim=0)
637
+ promots5 = torch.concatenate([self.pos_promot5[str(i)] for i in dataset_idx], dim=0)
638
+ return promots1, promots2, promots3, promots4, promots5
639
+
640
+ def forward(self, x, dataset_idx, return_features=False):
641
+
642
+ if isinstance(dataset_idx, torch.Tensor):
643
+ dataset_idx = list(dataset_idx.cpu().numpy())
644
+ cha_promots1, cha_promots2, cha_promots3, cha_promots4, cha_promots5 = self.get_cha_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))
645
+ pos_promots1, pos_promots2, pos_promots3, pos_promots4, pos_promots5 = self.get_pos_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))
646
+ x1, x2, x3, x4, x5 = self.encoder(x)
647
+
648
+ if return_features:
649
+ pre_x1, pre_x2, pre_x3, pre_x4, pre_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone()
650
+
651
+ h1, w1 = x1.size()[2:]
652
+ h2, w2 = x2.size()[2:]
653
+ h3, w3 = x3.size()[2:]
654
+ h4, w4 = x4.size()[2:]
655
+ h5, w5 = x5.size()[2:]
656
+ x1, (Hp1, Wp1), (h_win1, w_win1) = window_partition(x1, self.local_window_sizes[0])
657
+ x2, (Hp2, Wp2), (h_win2, w_win2) = window_partition(x2, self.local_window_sizes[1])
658
+ x3, (Hp3, Wp3), (h_win3, w_win3) = window_partition(x3, self.local_window_sizes[2])
659
+ x4, (Hp4, Wp4), (h_win4, w_win4) = window_partition(x4, self.local_window_sizes[3])
660
+ x5, (Hp5, Wp5), (h_win5, w_win5) = window_partition(x5, self.local_window_sizes[4])
661
+
662
+ cha_promots1 = prompt_partition(cha_promots1, h_win1, w_win1)
663
+ cha_promots2 = prompt_partition(cha_promots2, h_win2, w_win2)
664
+ cha_promots3 = prompt_partition(cha_promots3, h_win3, w_win3)
665
+ cha_promots4 = prompt_partition(cha_promots4, h_win4, w_win4)
666
+ cha_promots5 = prompt_partition(cha_promots5, h_win5, w_win5)
667
+
668
+ pos_promots1 = prompt_partition(pos_promots1, h_win1, w_win1)
669
+ pos_promots2 = prompt_partition(pos_promots2, h_win2, w_win2)
670
+ pos_promots3 = prompt_partition(pos_promots3, h_win3, w_win3)
671
+ pos_promots4 = prompt_partition(pos_promots4, h_win4, w_win4)
672
+ pos_promots5 = prompt_partition(pos_promots5, h_win5, w_win5)
673
+
674
+ cha_x1, cha_x2, cha_x3, cha_x4, cha_x5 = torch.cat([x1, cha_promots1], dim=1), torch.cat([x2, cha_promots2], dim=1), torch.cat([x3, cha_promots3], dim=1), torch.cat([x4, cha_promots4], dim=1), torch.cat([x5, cha_promots5], dim=1)
675
+ pos_x1, pos_x2, pos_x3, pos_x4, pos_x5 = torch.cat([pos_promots1, x1], dim=2), torch.cat([pos_promots2, x2], dim=2), torch.cat([pos_promots3, x3], dim=2), torch.cat([pos_promots4, x4], dim=2), torch.cat([pos_promots5, x5], dim=2)
676
+
677
+ x1, x2, x3, x4, x5 = self.att1(pos_x1, cha_x1), self.att2(pos_x2, cha_x2), self.att3(pos_x3, cha_x3), self.att4(pos_x4, cha_x4), self.att5(pos_x5, cha_x5)
678
+
679
+ x1 = window_unpartition(x1, self.local_window_sizes[0], (Hp1, Wp1), (h1, w1))
680
+ x2 = window_unpartition(x2, self.local_window_sizes[1], (Hp2, Wp2), (h2, w2))
681
+ x3 = window_unpartition(x3, self.local_window_sizes[2], (Hp3, Wp3), (h3, w3))
682
+ x4 = window_unpartition(x4, self.local_window_sizes[3], (Hp4, Wp4), (h4, w4))
683
+ x5 = window_unpartition(x5, self.local_window_sizes[4], (Hp5, Wp5), (h5, w5))
684
+
685
+ if return_features:
686
+ pro_x1, pro_x2, pro_x3, pro_x4, pro_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone()
687
+
688
+ return (pre_x1, pre_x2, pre_x3, pre_x4, pre_x5), (pro_x1, pro_x2, pro_x3, pro_x4, pro_x5)
689
+
690
+ x = self.decoder(x1, x2, x3, x4, x5)
691
+ x = self.Last_Conv(x)
692
+
693
+ return x
694
+
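The prompting path in `Prompt_U_Net_P_DCP.forward` leans on the three helpers defined at the top of this file: `window_partition` tiles a feature map into non-overlapping windows (zero-padding to a multiple of the window size), `prompt_partition` repeats each sample's prompt so that every window receives its own copy, and `window_unpartition` reassembles the map and crops the padding off again. A small shape sketch under assumed sizes (batch 2, 24 channels, a 100x100 map, window 32):

```python
# Shape sketch only; the tensor sizes are assumptions for illustration.
import torch
from models.UNet_p import window_partition, window_unpartition, prompt_partition

x = torch.randn(2, 24, 100, 100)                        # B, C, H, W
wins, (Hp, Wp), (h_win, w_win) = window_partition(x, 32)
print(wins.shape, (Hp, Wp), (h_win, w_win))             # [32, 24, 32, 32], (128, 128), (4, 4)

prompt = torch.zeros(2, 8, 32, 32)                      # one prompt per sample
print(prompt_partition(prompt, h_win, w_win).shape)     # [32, 8, 32, 32]: one copy per window

x_back = window_unpartition(wins, 32, (Hp, Wp), (100, 100))
print(torch.allclose(x, x_back))                        # True: the zero padding is cropped again
```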
models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .models import build_model
models/backbones/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .backbones import build_backbone
models/backbones/backbones.py ADDED
@@ -0,0 +1,25 @@
1
+ import os
2
+ import torch
3
+ from timm.models import efficientnet, convnext
4
+
5
+
6
+ def build_backbone(model_name, pretrained):
7
+ model = getattr(Backbones, model_name)(pretrained=pretrained)
8
+ return model
9
+
10
+
11
+ class Backbones(object):
12
+ @staticmethod
13
+ def efficientnet_b3_p(pretrained):
14
+ # channels: 24, 12, 40, 120, 384
15
+ # for testing, pretrained can be set to False
16
+ model = efficientnet.efficientnet_b3_pruned(pretrained=pretrained, features_only=True)
17
+
18
+ '''
19
+ # pre-downloaded weights
20
+ cp_path = os.path.join('checkpoints', 'effnetb3_pruned-59ecf72d.pth')
21
+ state_dict = torch.load(cp_path, map_location=torch.device('cpu'))
22
+ model.load_state_dict(state_dict=state_dict, strict=False)'''
23
+ return model
24
+
25
+
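`build_backbone` simply resolves a static method by name, so the encoder used by the U-Net variants is a timm EfficientNet created with `features_only=True`, i.e. it returns one feature map per stage (five maps, with the channel counts noted in the comment above). A quick hedged check:

```python
# Hedged sketch; the input size is arbitrary and pretrained=False avoids a weight download.
import torch
from models.backbones import build_backbone

encoder = build_backbone('efficientnet_b3_p', pretrained=False)
feats = encoder(torch.randn(1, 3, 256, 256))
print([tuple(f.shape) for f in feats])   # five maps at strides 2/4/8/16/32
```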
models/crit/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .focal_loss import BFLoss
2
+ from .mmd import MMDLinear
3
+ from .dice import DiceLoss, DiceBCE
4
+ from .get_bd import generate_BD
models/crit/dice.py ADDED
@@ -0,0 +1,36 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ e = 1e-10  # smoothing constant
7
+
8
+
9
+ def dice_loss(pred, target, need_sigmoid=True):
10
+ assert target.size() == pred.size()
11
+ if need_sigmoid:
12
+ pred = torch.sigmoid(pred)
13
+ intersect = 2 * (pred * target).sum() + e
14
+ union = (pred * pred).sum() + (target * target).sum() + e
15
+ return 1 - intersect / union
16
+
17
+
18
+ class DiceLoss(nn.Module):
19
+ def __init__(self):
20
+ super().__init__()
21
+
22
+ def forward(self, pred, target):
23
+ return dice_loss(pred=pred, target=target)
24
+
25
+
26
+ class DiceBCE(nn.Module):
27
+ def __init__(self):
28
+ super().__init__()
29
+
30
+ def forward(self, pred, target):
31
+ return 0.5 * dice_loss(pred=pred, target=target) + \
32
+ 0.5 * F.binary_cross_entropy_with_logits(input=pred, target=target)
33
+
34
+
35
+
36
+
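With the smoothing constant set to `1e-10`, `dice_loss` is the usual soft Dice computed over the whole batch: `1 - 2*sum(p*t) / (sum(p**2) + sum(t**2))` after a sigmoid on the logits. A toy sanity check (tensors are made up, not from the repo):

```python
# Toy check of the soft Dice behaviour at the two extremes.
import torch
from models.crit.dice import dice_loss

pred = torch.full((1, 1, 4, 4), 10.0)                    # large logits -> sigmoid ~= 1 everywhere
print(float(dice_loss(pred, torch.ones(1, 1, 4, 4))))    # ~0.0: perfect overlap
print(float(dice_loss(pred, torch.zeros(1, 1, 4, 4))))   # ~1.0: no overlap at all
```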
models/crit/focal_loss.py ADDED
@@ -0,0 +1,23 @@
1
+ from torch import nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+
5
+ def binary_focal_loss(pred, target, alpha=0.5, gamma=2):
6
+ assert pred.size() == target.size()
7
+ pred = torch.sigmoid(pred)
8
+ e = 1e-5
9
+ loss = alpha * target * (1 - pred) ** gamma * (pred + e).log() + (1 - alpha) * (1 - target) * pred ** gamma * (1 - pred + e).log()
10
+ loss = loss / (0.5 ** gamma)
11
+ return -loss.mean()
12
+
13
+
14
+ class BFLoss(nn.Module):
15
+ def __init__(self, alpha=0.5, gamma=2):
16
+ super(BFLoss, self).__init__()
17
+ # alpha: weight of the foreground class
18
+ self.gamma = gamma
19
+ self.alpha = alpha
20
+
21
+ def forward(self, pred, target, *args, **kwargs):
22
+ return binary_focal_loss(pred, target, alpha=self.alpha, gamma=self.gamma)
23
+
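`binary_focal_loss` is the standard binary focal loss with an extra `1 / 0.5**gamma` rescaling so its magnitude stays comparable across `gamma` values; in particular, with `gamma=0` and `alpha=0.5` it reduces to exactly half of binary cross-entropy (up to the `1e-5` epsilon inside the logs). A quick hedged check:

```python
# Sanity check with random logits and targets (illustrative values only).
import torch
import torch.nn.functional as F
from models.crit.focal_loss import binary_focal_loss

pred = torch.randn(2, 1, 8, 8)
target = (torch.rand(2, 1, 8, 8) > 0.5).float()
fl = binary_focal_loss(pred, target, alpha=0.5, gamma=0)
bce = F.binary_cross_entropy_with_logits(pred, target)
print(float(fl), float(0.5 * bce))   # nearly identical
```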
models/crit/get_bd.py ADDED
@@ -0,0 +1,12 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def generate_BD(mask):
6
+ #print(mask.size())
7
+ # img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
8
+ # mask = mask.float()
9
+ mask = torch.abs(mask - F.max_pool2d(mask, 3, 1, 1))
10
+ mask = mask.detach()
11
+
12
+ return mask
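`generate_BD` derives a boundary band from a binary mask without any learnable parameters: max-pooling with a 3x3 kernel (stride 1, padding 1) dilates the foreground by one pixel, and the absolute difference with the original mask keeps only the background pixels that touch the object. A toy illustration (mask values assumed):

```python
# Toy boundary extraction on a 3x3 square of foreground.
import torch
from models.crit.get_bd import generate_BD

mask = torch.zeros(1, 1, 7, 7)
mask[..., 2:5, 2:5] = 1.0
print(generate_BD(mask)[0, 0].int())
# the one-pixel ring around the square becomes 1; the square itself and far background stay 0
```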
models/crit/mmd.py ADDED
@@ -0,0 +1,28 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from math import gcd
4
+
5
+
6
+ def mmd_linear(f_of_X, f_of_Y):
7
+ delta = f_of_X - f_of_Y
8
+ loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))
9
+ return loss
10
+
11
+ class MMDLinear(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def forward(self, fea_source, fea_target):
16
+ n_s, d_s = fea_source.size()
17
+ n_t, d_t = fea_target.size()
18
+
19
+ assert d_s == d_t
20
+
21
+ if n_s != n_t:
22
+ n = int(n_s * n_t / gcd(n_s, n_t)) # least common multiple
23
+
24
+ fea_source = fea_source.repeat((int(n / n_s), 1))
25
+ fea_target = fea_target.repeat((int(n / n_t), 1))
26
+ return mmd_linear(fea_source, fea_target)
27
+
28
+
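`MMDLinear` computes the linear MMD estimate `mean((X - Y) @ (X - Y).T)` after repeating the smaller batch so that both feature matrices reach the least common multiple of the two batch sizes and can be subtracted row by row. A minimal sketch with made-up feature batches:

```python
# Minimal sketch; feature dimensions and batch sizes are arbitrary.
import torch
from models.crit.mmd import MMDLinear

crit = MMDLinear()
src = torch.randn(4, 16)   # 4 source features
tgt = torch.randn(6, 16)   # 6 target features; lcm(4, 6) = 12, so both are tiled to 12 rows
print(float(crit(src, tgt)))   # non-negative scalar, larger when the two sets differ more
```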
models/jtfn.py ADDED
@@ -0,0 +1,465 @@
1
+ # ICCV2021, Joint Topology-preserving and Feature-refinement Network for Curvilinear Structure Segmentation
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .UNet_p import MultiHeadAttention2D_Dual2_2, rand, window_partition, window_unpartition, prompt_partition, OneLayerRes
6
+
7
+
8
+
9
+
10
+ class SpatialAttention(nn.Module):
11
+ def __init__(self):
12
+ super(SpatialAttention, self).__init__()
13
+ self.conv = nn.Sequential(
14
+ nn.Conv2d(2, 1, kernel_size=(3, 3), padding=(1, 1)),
15
+ nn.Conv2d(1, 1, kernel_size=(5, 5), padding=(2, 2)),
16
+ nn.Sigmoid()
17
+ )
18
+
19
+ def forward(self, x):
20
+ avg_out = torch.mean(x, dim=1, keepdim=True)
21
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
22
+ x = torch.cat([avg_out, max_out], dim=1)
23
+ x = self.conv(x)
24
+ return x
25
+
26
+
27
+ class ChannelAttention(nn.Module):
28
+ def __init__(self, channel, reduction=2):
29
+ super(ChannelAttention, self).__init__()
30
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
31
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
32
+
33
+ self.fc1 = nn.Conv2d(channel, channel // reduction, 1, bias=False)
34
+ self.fc2 = nn.Conv2d(channel // reduction, channel, 1, bias=False)
35
+ self.activate = nn.Sigmoid()
36
+
37
+ def forward(self, x):
38
+ avg_out = self.fc2(self.fc1(self.avg_pool(x)))
39
+ max_out = self.fc2(self.fc1(self.max_pool(x)))
40
+ out = avg_out + max_out
41
+ out = self.activate(out)
42
+ return out
43
+
44
+
45
+ class GAU(nn.Module):
46
+ def __init__(self, in_channels, use_gau=True, reduce_dim=False, out_channels=None):
47
+ super(GAU, self).__init__()
48
+ self.use_gau = use_gau
49
+ self.reduce_dim = reduce_dim
50
+
51
+ if self.reduce_dim:
52
+ self.down_conv = nn.Sequential(
53
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
54
+ nn.BatchNorm2d(out_channels),
55
+ nn.ReLU(inplace=True)
56
+ )
57
+ in_channels = out_channels
58
+
59
+ if self.use_gau:
60
+
61
+ self.sa = SpatialAttention()
62
+ self.ca = ChannelAttention(in_channels)
63
+
64
+ self.reset_gate = nn.Sequential(
65
+ nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=2, dilation=2),
66
+ nn.BatchNorm2d(out_channels),
67
+ nn.ReLU(inplace=True),
68
+ )
69
+
70
+ def forward(self, x, y):
71
+ if self.reduce_dim:
72
+ x = self.down_conv(x)
73
+
74
+ if self.use_gau:
75
+ y = F.interpolate(y, x.shape[-2:], mode='bilinear', align_corners=True)
76
+
77
+ comx = x * y
78
+ resx = x * (1 - y) # bs, c, h, w
79
+
80
+ x_sa = self.sa(resx) # bs, 1, h, w
81
+ x_ca = self.ca(resx) # bs, c, 1, 1
82
+
83
+ O = self.reset_gate(comx)
84
+ M = x_sa * x_ca
85
+
86
+ RF = M * x + (1 - M) * O
87
+ else:
88
+ RF = x
89
+ return RF
90
+
91
+
92
+ class FIM(nn.Module):
93
+
94
+ def __init__(self, in_channels, out_channels, f_channels, use_topo=True, up=True, bottom=False):
95
+ super(FIM, self).__init__()
96
+ self.use_topo = use_topo
97
+ self.up = up
98
+ self.bottom = bottom
99
+
100
+ if self.up:
101
+ self.up_s = nn.Sequential(
102
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
103
+ nn.BatchNorm2d(out_channels),
104
+ nn.ReLU(inplace=True)
105
+ )
106
+ self.up_t = nn.Sequential(
107
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
108
+ nn.BatchNorm2d(out_channels),
109
+ nn.ReLU(inplace=True)
110
+ )
111
+ else:
112
+ self.up_s = nn.Sequential(
113
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
114
+ nn.BatchNorm2d(out_channels),
115
+ nn.ReLU(inplace=True)
116
+ )
117
+ self.up_t = nn.Sequential(
118
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
119
+ nn.BatchNorm2d(out_channels),
120
+ nn.ReLU(inplace=True)
121
+ )
122
+
123
+ self.decoder_s = nn.Sequential(
124
+ nn.Conv2d(out_channels + f_channels, out_channels, kernel_size=3, stride=1, padding=1),
125
+ nn.BatchNorm2d(out_channels),
126
+ nn.ReLU(inplace=True),
127
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
128
+ nn.BatchNorm2d(out_channels),
129
+ nn.ReLU(inplace=True)
130
+ )
131
+
132
+ '''self.inner_s = nn.Sequential(
133
+ nn.Conv2d(out_channels, 1, kernel_size=3, padding=1, bias=False),
134
+ nn.Sigmoid()
135
+ )'''
136
+
137
+ if self.bottom:
138
+ self.st = nn.Sequential(
139
+ nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1),
140
+ nn.BatchNorm2d(in_channels),
141
+ nn.ReLU(inplace=True)
142
+ )
143
+
144
+ if self.use_topo:
145
+ self.decoder_t = nn.Sequential(
146
+ nn.Conv2d(out_channels + out_channels, out_channels, kernel_size=3, stride=1, padding=1),
147
+ nn.BatchNorm2d(out_channels),
148
+ nn.ReLU(inplace=True)
149
+ )
150
+
151
+ self.s_to_t = nn.Sequential(
152
+ nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1),
153
+ nn.BatchNorm2d(out_channels),
154
+ nn.ReLU(inplace=True)
155
+ )
156
+
157
+ self.t_to_s = nn.Sequential(
158
+ nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1),
159
+ nn.BatchNorm2d(out_channels),
160
+ nn.ReLU(inplace=True)
161
+ )
162
+
163
+ self.res_s = nn.Sequential(
164
+ nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, stride=1, padding=1),
165
+ nn.BatchNorm2d(out_channels),
166
+ nn.ReLU(inplace=True)
167
+ )
168
+
169
+ '''self.inner_t = nn.Sequential(
170
+ nn.Conv2d(out_channels, 1, kernel_size=3, padding=1, bias=False),
171
+ nn.Sigmoid()
172
+ )'''
173
+
174
+ def forward(self, x_s, x_t, rf):
175
+ if self.use_topo:
176
+ if self.bottom:
177
+ x_t = self.st(x_t)
178
+ #bs, c, h, w = x_s.shape
179
+ x_s = self.up_s(x_s)
180
+ x_t = self.up_t(x_t)
181
+
182
+ # padding
183
+ diffY = rf.size()[2] - x_s.size()[2]
184
+ diffX = rf.size()[3] - x_s.size()[3]
185
+
186
+ x_s = F.pad(x_s, [diffX // 2, diffX - diffX // 2,
187
+ diffY // 2, diffY - diffY // 2])
188
+ x_t = F.pad(x_t, [diffX // 2, diffX - diffX // 2,
189
+ diffY // 2, diffY - diffY // 2])
190
+
191
+ rf_s = torch.cat((x_s, rf), dim=1)
192
+ s = self.decoder_s(rf_s)
193
+ s_t = self.s_to_t(s)
194
+
195
+ t = torch.cat((x_t, s_t), dim=1)
196
+ x_t = self.decoder_t(t)
197
+ t_s = self.t_to_s(x_t)
198
+
199
+ s_res = self.res_s(torch.cat((s, t_s), dim=1))
200
+
201
+ x_s = s + s_res
202
+ # t_cls = self.inner_t(x_t)
203
+ # s_cls = self.inner_s(x_s)
204
+ else:
205
+ x_s = self.up_s(x_s)
206
+ #x_b = self.up_b(x_b)
207
+ # padding
208
+ diffY = rf.size()[2] - x_s.size()[2]
209
+ diffX = rf.size()[3] - x_s.size()[3]
210
+
211
+ x_s = F.pad(x_s, [diffX // 2, diffX - diffX // 2,
212
+ diffY // 2, diffY - diffY // 2])
213
+
214
+ rf_s = torch.cat((x_s, rf), dim=1)
215
+ s = self.decoder_s(rf_s)
216
+ x_s = s
217
+ x_t = x_s
218
+ #t_cls = None
219
+ #s_cls = self.inner_s(x_s)
220
+ return x_s, x_t
221
+
222
+
223
+ class JTFNDecoder(nn.Module):
224
+ def __init__(self, channels, use_topo) -> None:
225
+ super().__init__()
226
+ self.skip_blocks = []
227
+ for i in range(5):
228
+ self.skip_blocks.append(GAU(channels[i], use_gau=True, reduce_dim=False, out_channels=channels[i]))
229
+ self.fims = []
230
+ index = 3
231
+ for i in range(4):
232
+ if i == index:
233
+ self.fims.append(FIM(channels[i+1], channels[i], channels[i], use_topo=use_topo, up=True, bottom=True))
234
+ else:
235
+ self.fims.append(FIM(channels[i+1], channels[i], channels[i], use_topo=use_topo, up=True, bottom=False))
236
+ self.skip_blocks = nn.ModuleList(self.skip_blocks)
237
+ self.fims = nn.ModuleList(self.fims)
238
+
239
+ def forward(self, x1, x2, x3, x4, x5, y):
240
+ x1 = self.skip_blocks[0](x1, y)
241
+ x2 = self.skip_blocks[1](x2, y)
242
+ x3 = self.skip_blocks[2](x3, y)
243
+ x4 = self.skip_blocks[3](x4, y)
244
+ x5 = self.skip_blocks[4](x5, y)
245
+
246
+ x5_seg, x5_bou = x5, x5
247
+
248
+ x4_seg, x4_bou = self.fims[3](x5_seg, x5_bou, x4)
249
+ x3_seg, x3_bou = self.fims[2](x4_seg, x4_bou, x3)
250
+ x2_seg, x2_bou = self.fims[1](x3_seg, x3_bou, x2)
251
+ x1_seg, x1_bou = self.fims[0](x2_seg, x2_bou, x1)
252
+
253
+
254
+ return [x1_seg, x2_seg, x3_seg, x4_seg], [x1_bou, x2_bou, x3_bou, x4_bou]
255
+
256
+
257
+ class JTFN(nn.Module):
258
+ def __init__(self, encoder, decoder, channels, num_classes, steps) -> None:
259
+ super().__init__()
260
+ self.encoder = encoder
261
+ self.decoder = decoder
262
+ self.num_classes = num_classes
263
+ self.steps = steps
264
+
265
+ self.conv_seg1_head = nn.Conv2d(channels[0], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
266
+ self.conv_seg2_head = nn.Conv2d(channels[1], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
267
+ self.conv_seg3_head = nn.Conv2d(channels[2], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
268
+ self.conv_seg4_head = nn.Conv2d(channels[3], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
269
+
270
+ self.conv_bou1_head = nn.Conv2d(channels[0], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
271
+ self.conv_bou2_head = nn.Conv2d(channels[1], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
272
+ self.conv_bou3_head = nn.Conv2d(channels[2], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
273
+ self.conv_bou4_head = nn.Conv2d(channels[3], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
274
+
275
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
276
+
277
+ def forward(self, x):
278
+ B, C, H, W = x.shape
279
+ y = torch.zeros([B, self.num_classes, H, W], device=x.device)
280
+
281
+ x1, x2, x3, x4, x5 = self.encoder(x)
282
+
283
+ outputs = {}
284
+ for i in range(self.steps):
285
+ segs, bous = self.decoder(x1, x2, x3, x4, x5, y)
286
+ x1_seg, x2_seg, x3_seg, x4_seg = segs
287
+ x1_bou, x2_bou, x3_bou, x4_bou = bous
288
+
289
+ x1_seg = self.conv_seg1_head(x1_seg)
290
+ x2_seg = self.conv_seg2_head(x2_seg)
291
+ x3_seg = self.conv_seg3_head(x3_seg)
292
+ x4_seg = self.conv_seg4_head(x4_seg)
293
+
294
+ x1_bou = self.conv_bou1_head(x1_bou)
295
+ x2_bou = self.conv_bou2_head(x2_bou)
296
+ x3_bou = self.conv_bou3_head(x3_bou)
297
+ x4_bou = self.conv_bou4_head(x4_bou)
298
+
299
+ y = x1_seg
300
+ outputs['step_{}_seg'.format(i)] = [x1_seg, x2_seg, x3_seg, x4_seg]
301
+ outputs['step_{}_bou'.format(i)] = [x1_bou, x2_bou, x3_bou, x4_bou]
302
+ y = self.upsample(y)
303
+ outputs['output'] = y
304
+ return outputs
305
+
306
+ def encoder_forward(self, x, dataset_idx):
307
+ # efficient net
308
+ x = self.encoder.conv_stem(x)
309
+ x = self.encoder.bn1(x)
310
+ features = []
311
+ if 0 in self.encoder._stage_out_idx:
312
+ features.append(x) # add stem out
313
+ for i in range(len(self.encoder.blocks)):
314
+ for j, l in enumerate(self.encoder.blocks[i]):
315
+ if j == len(self.encoder.blocks[i]) - 1 and i + 1 in self.encoder._stage_out_idx:
316
+ x = l(x, dataset_idx)
317
+ else:
318
+ x = l(x)
319
+ if i + 1 in self.encoder._stage_out_idx:
320
+ features.append(x)
321
+ return features
322
+
323
+
324
+
325
+ class JTFN_DCP(JTFN):
326
+ def __init__(self, encoder, decoder, channels, num_classes, steps, dataset_idx,
327
+ local_window_sizes, encoder_channels, pos_promot_channels, cha_promot_channels,
328
+ embed_ratio, strides, att_fusion, use_conv) -> None:
329
+ super().__init__(encoder, decoder, channels, num_classes, steps)
330
+ self.dataset_idx = dataset_idx
331
+ self.local_window_sizes = local_window_sizes
332
+
333
+ self.pos_promot_channels = pos_promot_channels
334
+ pos_p1 = rand((1, encoder_channels[0], pos_promot_channels[0], local_window_sizes[0]), val=3. / encoder_channels[0] ** 0.5)
335
+ pos_p2 = rand((1, encoder_channels[1], pos_promot_channels[1], local_window_sizes[1]), val=3. / encoder_channels[1] ** 0.5)
336
+ pos_p3 = rand((1, encoder_channels[2], pos_promot_channels[2], local_window_sizes[2]), val=3. / encoder_channels[2] ** 0.5)
337
+ pos_p4 = rand((1, encoder_channels[3], pos_promot_channels[3], local_window_sizes[3]), val=3. / encoder_channels[3] ** 0.5)
338
+ pos_p5 = rand((1, encoder_channels[4], pos_promot_channels[4], local_window_sizes[4]), val=3. / encoder_channels[4] ** 0.5)
339
+ self.pos_promot1 = nn.ParameterDict({str(k): nn.Parameter(pos_p1.detach().clone(), requires_grad=True) for k in dataset_idx})
340
+ self.pos_promot2 = nn.ParameterDict({str(k): nn.Parameter(pos_p2.detach().clone(), requires_grad=True) for k in dataset_idx})
341
+ self.pos_promot3 = nn.ParameterDict({str(k): nn.Parameter(pos_p3.detach().clone(), requires_grad=True) for k in dataset_idx})
342
+ self.pos_promot4 = nn.ParameterDict({str(k): nn.Parameter(pos_p4.detach().clone(), requires_grad=True) for k in dataset_idx})
343
+ self.pos_promot5 = nn.ParameterDict({str(k): nn.Parameter(pos_p5.detach().clone(), requires_grad=True) for k in dataset_idx})
344
+
345
+ self.cha_promot_channels = cha_promot_channels
346
+ cha_p1 = rand((1, cha_promot_channels[0], local_window_sizes[0], local_window_sizes[0]), val=3. / local_window_sizes[0])
347
+ cha_p2 = rand((1, cha_promot_channels[1], local_window_sizes[1], local_window_sizes[1]), val=3. / local_window_sizes[1])
348
+ cha_p3 = rand((1, cha_promot_channels[2], local_window_sizes[2], local_window_sizes[2]), val=3. / local_window_sizes[2])
349
+ cha_p4 = rand((1, cha_promot_channels[3], local_window_sizes[3], local_window_sizes[3]), val=3. / local_window_sizes[3])
350
+ cha_p5 = rand((1, cha_promot_channels[4], local_window_sizes[4], local_window_sizes[4]), val=3. / local_window_sizes[4])
351
+ self.cha_promot1 = nn.ParameterDict({str(k): nn.Parameter(cha_p1.detach().clone(), requires_grad=True) for k in dataset_idx})
352
+ self.cha_promot2 = nn.ParameterDict({str(k): nn.Parameter(cha_p2.detach().clone(), requires_grad=True) for k in dataset_idx})
353
+ self.cha_promot3 = nn.ParameterDict({str(k): nn.Parameter(cha_p3.detach().clone(), requires_grad=True) for k in dataset_idx})
354
+ self.cha_promot4 = nn.ParameterDict({str(k): nn.Parameter(cha_p4.detach().clone(), requires_grad=True) for k in dataset_idx})
355
+ self.cha_promot5 = nn.ParameterDict({str(k): nn.Parameter(cha_p5.detach().clone(), requires_grad=True) for k in dataset_idx})
356
+
357
+ self.strides = strides
358
+
359
+ self.att1 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[0], dim_cha=encoder_channels[0] + cha_promot_channels[0], embed_dim=encoder_channels[0], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[0], pos_slide=0, cha_slide=0, use_conv=use_conv)
360
+ self.att2 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[1], dim_cha=encoder_channels[1] + cha_promot_channels[1], embed_dim=encoder_channels[1], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[1], pos_slide=0, cha_slide=0, use_conv=use_conv)
361
+ self.att3 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[2], dim_cha=encoder_channels[2] + cha_promot_channels[2], embed_dim=encoder_channels[2], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[2], pos_slide=0, cha_slide=0, use_conv=use_conv)
362
+ self.att4 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[3], dim_cha=encoder_channels[3] + cha_promot_channels[3], embed_dim=encoder_channels[3], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[3], pos_slide=0, cha_slide=0, use_conv=use_conv)
363
+ self.att5 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[4], dim_cha=encoder_channels[4] + cha_promot_channels[4], embed_dim=encoder_channels[4], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[4], pos_slide=0, cha_slide=0, use_conv=use_conv)
364
+
365
+ def get_cha_prompts(self, dataset_idx, batch_size):
366
+ if len(dataset_idx) != batch_size:
367
+ raise Exception(dataset_idx, self.dataset_idx, batch_size)
368
+ # print(dataset_idx, '***')
369
+ promots1 = torch.concatenate([self.cha_promot1[str(i)] for i in dataset_idx], dim=0)
370
+ promots2 = torch.concatenate([self.cha_promot2[str(i)] for i in dataset_idx], dim=0)
371
+ promots3 = torch.concatenate([self.cha_promot3[str(i)] for i in dataset_idx], dim=0)
372
+ promots4 = torch.concatenate([self.cha_promot4[str(i)] for i in dataset_idx], dim=0)
373
+ promots5 = torch.concatenate([self.cha_promot5[str(i)] for i in dataset_idx], dim=0)
374
+ return promots1, promots2, promots3, promots4, promots5
375
+
376
+ def get_pos_prompts(self, dataset_idx, batch_size):
377
+ if len(dataset_idx) != batch_size:
378
+ raise ValueError('expected {} dataset indices, got {}: {}'.format(batch_size, len(dataset_idx), dataset_idx))
379
+ # print(dataset_idx, '***')
380
+ promots1 = torch.concatenate([self.pos_promot1[str(i)] for i in dataset_idx], dim=0)
381
+ promots2 = torch.concatenate([self.pos_promot2[str(i)] for i in dataset_idx], dim=0)
382
+ promots3 = torch.concatenate([self.pos_promot3[str(i)] for i in dataset_idx], dim=0)
383
+ promots4 = torch.concatenate([self.pos_promot4[str(i)] for i in dataset_idx], dim=0)
384
+ promots5 = torch.concatenate([self.pos_promot5[str(i)] for i in dataset_idx], dim=0)
385
+ return promots1, promots2, promots3, promots4, promots5
386
+
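+ # Both lookups above build a per-sample prompt batch by indexing the ParameterDicts
+ # with each sample's dataset index, e.g. (illustrative values)
+ #   dataset_idx = [0, 0, 2]
+ #   promots1 = cat([pos_promot1['0'], pos_promot1['0'], pos_promot1['2']], dim=0)
+ # so every returned tensor has a batch dimension of len(dataset_idx).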
387
+ def forward(self, x, dataset_idx, return_features=False):
388
+ if isinstance(dataset_idx, torch.Tensor):
389
+ dataset_idx = list(dataset_idx.cpu().numpy())
390
+ #print(dataset_idx)
391
+ cha_promots1, cha_promots2, cha_promots3, cha_promots4, cha_promots5 = self.get_cha_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))
392
+ pos_promots1, pos_promots2, pos_promots3, pos_promots4, pos_promots5 = self.get_pos_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))
393
+
394
+ B, C, H, W = x.shape
395
+ y = torch.zeros([B, self.num_classes, H, W], device=x.device)
396
+
397
+ x1, x2, x3, x4, x5 = self.encoder(x)
398
+
399
+ if return_features:
400
+ pre_x1, pre_x2, pre_x3, pre_x4, pre_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone()
401
+ h1, w1 = x1.size()[2:]
402
+ h2, w2 = x2.size()[2:]
403
+ h3, w3 = x3.size()[2:]
404
+ h4, w4 = x4.size()[2:]
405
+ h5, w5 = x5.size()[2:]
406
+ x1, (Hp1, Wp1), (h_win1, w_win1) = window_partition(x1, self.local_window_sizes[0])
407
+ x2, (Hp2, Wp2), (h_win2, w_win2) = window_partition(x2, self.local_window_sizes[1])
408
+ x3, (Hp3, Wp3), (h_win3, w_win3) = window_partition(x3, self.local_window_sizes[2])
409
+ x4, (Hp4, Wp4), (h_win4, w_win4) = window_partition(x4, self.local_window_sizes[3])
410
+ x5, (Hp5, Wp5), (h_win5, w_win5) = window_partition(x5, self.local_window_sizes[4])
411
+
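+ # window_partition presumably pads each feature map to a multiple of the local window
+ # size and returns the windows together with the padded size (Hp, Wp) and the window
+ # grid (h_win, w_win); prompt_partition below is then expected to tile each per-sample
+ # prompt so that one copy is paired with every window of that sample.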
412
+ cha_promots1 = prompt_partition(cha_promots1, h_win1, w_win1)
413
+ cha_promots2 = prompt_partition(cha_promots2, h_win2, w_win2)
414
+ cha_promots3 = prompt_partition(cha_promots3, h_win3, w_win3)
415
+ cha_promots4 = prompt_partition(cha_promots4, h_win4, w_win4)
416
+ cha_promots5 = prompt_partition(cha_promots5, h_win5, w_win5)
417
+
418
+ pos_promots1 = prompt_partition(pos_promots1, h_win1, w_win1)
419
+ pos_promots2 = prompt_partition(pos_promots2, h_win2, w_win2)
420
+ pos_promots3 = prompt_partition(pos_promots3, h_win3, w_win3)
421
+ pos_promots4 = prompt_partition(pos_promots4, h_win4, w_win4)
422
+ pos_promots5 = prompt_partition(pos_promots5, h_win5, w_win5)
423
+
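+ # Channel prompts are appended along the channel axis (dim=1), which is why every
+ # attention block above is built with dim_cha = encoder_channels[i] + cha_promot_channels[i];
+ # positional prompts are instead concatenated along the height axis (dim=2).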
424
+ #print(x1.size(), x2.size(), x3.size(), x4.size(), x5.size())
425
+ cha_x1, cha_x2, cha_x3, cha_x4, cha_x5 = torch.cat([x1, cha_promots1], dim=1), torch.cat([x2, cha_promots2], dim=1), torch.cat([x3, cha_promots3], dim=1), torch.cat([x4, cha_promots4], dim=1), torch.cat([x5, cha_promots5], dim=1)
426
+ pos_x1, pos_x2, pos_x3, pos_x4, pos_x5 = torch.cat([pos_promots1, x1], dim=2), torch.cat([pos_promots2, x2], dim=2), torch.cat([pos_promots3, x3], dim=2), torch.cat([pos_promots4, x4], dim=2), torch.cat([pos_promots5, x5], dim=2)
427
+
428
+ #print(x1.size(), x2.size(), x3.size(), x4.size(), x5.size())
429
+ x1, x2, x3, x4, x5 = self.att1(pos_x1, cha_x1), self.att2(pos_x2, cha_x2), self.att3(pos_x3, cha_x3), self.att4(pos_x4, cha_x4), self.att5(pos_x5, cha_x5)
430
+
431
+ x1 = window_unpartition(x1, self.local_window_sizes[0], (Hp1, Wp1), (h1, w1))
432
+ x2 = window_unpartition(x2, self.local_window_sizes[1], (Hp2, Wp2), (h2, w2))
433
+ x3 = window_unpartition(x3, self.local_window_sizes[2], (Hp3, Wp3), (h3, w3))
434
+ x4 = window_unpartition(x4, self.local_window_sizes[3], (Hp4, Wp4), (h4, w4))
435
+ x5 = window_unpartition(x5, self.local_window_sizes[4], (Hp5, Wp5), (h5, w5))
436
+
437
+ if return_features:
438
+ pro_x1, pro_x2, pro_x3, pro_x4, pro_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone()
439
+
440
+ return (pre_x1, pre_x2, pre_x3, pre_x4, pre_x5), (pro_x1, pro_x2, pro_x3, pro_x4, pro_x5)
441
+
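+ # Iterative refinement: the decoder is run self.steps times, feeding the previous
+ # step's finest segmentation map back in as y; the per-step multi-scale segmentation
+ # and boundary predictions are collected in `outputs` for deep supervision.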
442
+ outputs = {}
443
+ for i in range(self.steps):
444
+ segs, bous = self.decoder(x1, x2, x3, x4, x5, y)
445
+ x1_seg, x2_seg, x3_seg, x4_seg = segs
446
+ x1_bou, x2_bou, x3_bou, x4_bou = bous
447
+
448
+ x1_seg = self.conv_seg1_head(x1_seg)
449
+ x2_seg = self.conv_seg2_head(x2_seg)
450
+ x3_seg = self.conv_seg3_head(x3_seg)
451
+ x4_seg = self.conv_seg4_head(x4_seg)
452
+
453
+ x1_bou = self.conv_bou1_head(x1_bou)
454
+ x2_bou = self.conv_bou2_head(x2_bou)
455
+ x3_bou = self.conv_bou3_head(x3_bou)
456
+ x4_bou = self.conv_bou4_head(x4_bou)
457
+
458
+ y = x1_seg
459
+ outputs['step_{}_seg'.format(i)] = [x1_seg, x2_seg, x3_seg, x4_seg]
460
+ outputs['step_{}_bou'.format(i)] = [x1_bou, x2_bou, x3_bou, x4_bou]
461
+ y = self.upsample(y)
462
+ outputs['output'] = y
463
+ return outputs
464
+
465
+
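+ # Expected interface of the module above (illustrative; tensor sizes are assumptions):
+ #   out = model(x, dataset_idx=torch.tensor([0]))   # x: (1, 3, H, W)
+ #   out['output']                                    # final segmentation logits
+ #   out['step_0_seg'], out['step_0_bou']             # per-step multi-scale predictions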
models/models.py ADDED
@@ -0,0 +1,131 @@
1
+ from .processor import Processor, DCPProcessor, JTFNProcessor, JTFNDCPProcessor
2
+ from .UNet_p import U_Net_P, R2AttUNetDecoder, UNetDecoder, Prompt_U_Net_P_DCP
3
+ from .jtfn import JTFN, JTFNDecoder, JTFN_DCP
4
+ from .backbones import build_backbone
5
+
6
+
7
+ def build_model(model_name, model_params, training, dataset_idx, pretrained):
8
+ model = getattr(Models, model_name)(model_params=model_params, training=training, dataset_idx=dataset_idx, pretrained=pretrained)
9
+ return model
10
+
11
+
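+ # Example (illustrative; parameter values are assumptions, not repo defaults):
+ #   model = build_model('prompt_effi_b3_p_unet_dcp',
+ #                       model_params={'n_class': 1, ...}, training=False,
+ #                       dataset_idx=[0, 1, 2, 3, 4], pretrained=True)
+ # Each static method of `Models` below is a factory that wires an encoder, a decoder
+ # and a processor together; build_model looks the factory up by name via getattr.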
12
+ class Models(object):
13
+ @staticmethod
14
+ def effi_b3_p_unet(model_params, training, dataset_idx, pretrained=True):
15
+ n_class = model_params['n_class']
16
+ channels = (24, 12, 40, 120, 384)
17
+
18
+ encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
19
+ decoder = UNetDecoder(channels=channels)
20
+
21
+ seg_net = U_Net_P(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class)
22
+ model = Processor(model=seg_net, training_params=model_params, training=training)
23
+ return model
24
+
25
+
26
+ @staticmethod
27
+ def effi_b3_p_r2attunet(model_params, training, dataset_idx, pretrained=True):
28
+ n_class = model_params['n_class']
29
+ channels = (24, 12, 40, 120, 384)
30
+
31
+ encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
32
+ decoder = R2AttUNetDecoder(channels=channels)
33
+
34
+ seg_net = U_Net_P(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class)
35
+ model = Processor(model=seg_net, training_params=model_params, training=training)
36
+ return model
37
+
38
+ @staticmethod
39
+ def effi_b3_p_jtfn(model_params, training, dataset_idx, pretrained=True):
40
+ n_class = model_params['n_class']
41
+ channels = (24, 12, 40, 120, 384)
42
+ steps = model_params['steps']
43
+
44
+ encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
45
+ decoder = JTFNDecoder(channels=channels, use_topo=True)
46
+
47
+ seg_net = JTFN(encoder=encoder, decoder=decoder, channels=channels, num_classes=n_class, steps=steps)
48
+ model = JTFNProcessor(model=seg_net, training_params=model_params, training=training)
49
+ return model
50
+
51
+
52
+ @staticmethod
53
+ def prompt_effi_b3_p_unet_dcp(model_params, training, dataset_idx, pretrained=True):
54
+ n_class = model_params['n_class']
55
+ channels = [24, 12, 40, 120, 384]
56
+
57
+ cha_promot_channels = model_params['cha_promot_channels']
58
+ pos_promot_channels = model_params['pos_promot_channels']
59
+ local_window_sizes = model_params['local_window_sizes']
60
+ att_fusion = model_params['att_fusion']
61
+ prompt_init = model_params.get('prompt_init', 'rand') # rand, zero, one
62
+ embed_ratio = model_params['embed_ratio']
63
+ strides = model_params['strides']
64
+ use_conv = model_params['use_conv']
65
+
66
+ encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
67
+ decoder = UNetDecoder(channels=channels)
68
+
69
+ seg_net = Prompt_U_Net_P_DCP(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class,
70
+ dataset_idx=dataset_idx, encoder_channels=channels, prompt_init=prompt_init,
71
+ cha_promot_channels=cha_promot_channels, pos_promot_channels=pos_promot_channels,
72
+ embed_ratio=embed_ratio, strides=strides, local_window_sizes=local_window_sizes,
73
+ att_fusion=att_fusion, use_conv=use_conv)
74
+
75
+ model = DCPProcessor(model=seg_net, training_params=model_params, training=training)
76
+ return model
77
+
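+ # The DCP factories read the following keys from model_params (values illustrative):
+ #   n_class, cha_promot_channels, pos_promot_channels, local_window_sizes,
+ #   att_fusion, embed_ratio, strides, use_conv, and the optional prompt_init
+ #   ('rand' | 'zero' | 'one', defaulting to 'rand' via .get()).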
78
+ @staticmethod
79
+ def prompt_effi_b3_p_r2attunet_dcp(model_params, training, dataset_idx, pretrained=True):
80
+ n_class = model_params['n_class']
81
+ channels = [24, 12, 40, 120, 384]
82
+
83
+ cha_promot_channels = model_params['cha_promot_channels']
84
+ pos_promot_channels = model_params['pos_promot_channels']
85
+ local_window_sizes = model_params['local_window_sizes']
86
+ att_fusion = model_params['att_fusion']
87
+ prompt_init = model_params.get('prompt_init', 'rand') # rand, zero, one
88
+ embed_ratio = model_params['embed_ratio']
89
+ strides = model_params['strides']
90
+ use_conv = model_params['use_conv']
91
+
92
+ encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
93
+ decoder = R2AttUNetDecoder(channels=channels)
94
+
95
+ seg_net = Prompt_U_Net_P_DCP(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class,
96
+ dataset_idx=dataset_idx, encoder_channels=channels, prompt_init=prompt_init,
97
+ cha_promot_channels=cha_promot_channels, pos_promot_channels=pos_promot_channels,
98
+ embed_ratio=embed_ratio, strides=strides, local_window_sizes=local_window_sizes,
99
+ att_fusion=att_fusion, use_conv=use_conv)
100
+
101
+ model = DCPProcessor(model=seg_net, training_params=model_params, training=training)
102
+ return model
103
+
104
+
105
+ @staticmethod
106
+ def prompt_effi_b3_p_jtfn_dcp(model_params, training, dataset_idx, pretrained=True):
107
+ n_class = model_params['n_class']
108
+ steps = model_params['steps']
109
+ channels = [24, 12, 40, 120, 384]
110
+
111
+ cha_promot_channels = model_params['cha_promot_channels']
112
+ pos_promot_channels = model_params['pos_promot_channels']
113
+ local_window_sizes = model_params['local_window_sizes']
114
+ att_fusion = model_params['att_fusion']
115
+ embed_ratio = model_params['embed_ratio']
116
+ strides = model_params['strides']
117
+ use_conv = model_params['use_conv']
118
+
119
+ encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
120
+ decoder = JTFNDecoder(channels=channels, use_topo=True)
121
+ seg_net = JTFN_DCP(encoder=encoder, decoder=decoder, channels=channels, num_classes=n_class, steps=steps,
122
+ dataset_idx=dataset_idx, local_window_sizes=local_window_sizes,
123
+ encoder_channels=channels,
124
+ cha_promot_channels=cha_promot_channels, pos_promot_channels=pos_promot_channels,
125
+ embed_ratio=embed_ratio, strides=strides,
126
+ att_fusion=att_fusion, use_conv=use_conv)
127
+
128
+ model = JTFNDCPProcessor(model=seg_net, training_params=model_params, training=training)
129
+ return model
130
+
131
+
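+ # Note: the `dataset_idx` argument of these factories is the collection of dataset /
+ # modality indices for which per-dataset prompts are allocated (one ParameterDict entry
+ # per index); at inference time the same indices are passed per sample to select which
+ # learned prompts are used.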
models/optimizer.py ADDED
@@ -0,0 +1,72 @@
1
+ import torch.optim as optim
2
+ import torch.nn as nn
3
+ import torch
4
+ import itertools
5
+
6
+
7
+ def add_full_model_gradient_clipping(optim_cls, clip_norm_val):
8
+
9
+ class FullModelGradientClippingOptimizer(optim_cls):
10
+ def step(self, closure=None):
11
+ all_params = itertools.chain(*[x["params"] for x in self.param_groups])
12
+ torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
13
+ super().step(closure=closure)
14
+
15
+ return FullModelGradientClippingOptimizer
16
+
17
+
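+ # Example (illustrative): wrap an optimizer class so that every step() first clips the
+ # global gradient norm over all parameter groups:
+ #   ClippedAdamW = add_full_model_gradient_clipping(optim.AdamW, clip_norm_val=1.0)
+ #   optimizer = ClippedAdamW(model.parameters(), lr=1e-4)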
18
+ class Optimizer(object):
19
+ def __init__(self, models, training_params, sep_lr=None, sep_params=None, gradient_clip=0):
20
+
21
+ params = []
22
+ for model in models:
23
+ if isinstance(model, nn.Parameter):
24
+ params += [model]
25
+ else:
26
+ params += list(model.parameters())
27
+ if sep_lr is not None:
28
+ print(sep_lr)
29
+ add_params = []
30
+ for model in sep_params:
31
+ if isinstance(model, nn.Parameter):
32
+ add_params += [model]
33
+ else:
34
+ add_params += list(model.parameters())
35
+ params = [{'params': params},
36
+ {'params': add_params, 'lr': sep_lr}]
37
+
38
+
39
+ self.lr = training_params['lr']
40
+ self.weight_decay = training_params['weight_decay']
41
+ method = training_params['optimizer']
42
+
43
+
44
+ if method == 'SGD':
45
+ self.momentum = training_params['momentum']
46
+ if gradient_clip > 0:
47
+ self.optim = add_full_model_gradient_clipping(optim.SGD, gradient_clip)(params, lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
48
+ else:
49
+ self.optim = optim.SGD(params, lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
50
+ elif method == 'AdamW':
51
+ self.optim = optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay)
52
+ else:
53
+ raise Exception('{} is not supported'.format(method))
54
+
55
+ schedule_name = training_params['lr_schedule']
56
+ schedule_params = training_params['schedule_params']
57
+ if schedule_name == 'CosineAnnealingLR':
58
+ schedule_params['T_max'] = training_params['inter_val'] * 4
59
+ self.lr_schedule = getattr(optim.lr_scheduler, schedule_name)(self.optim, **schedule_params)
60
+
61
+ def update_lr(self):
62
+ self.lr_schedule.step()
63
+
64
+ def z_grad(self):
65
+ self.optim.zero_grad()
66
+
67
+ def g_step(self):
68
+ self.optim.step()
69
+
70
+ def get_lr(self):
71
+ for param_group in self.optim.param_groups:
72
+ return param_group['lr']
models/processor.py ADDED
@@ -0,0 +1,280 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .optimizer import Optimizer
4
+ from .crit import DiceBCE, generate_BD
5
+ from collections import OrderedDict
6
+
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class BasicProcessor(object):
11
+ def __init__(self) -> None:
12
+ pass
13
+
14
+ def fit(self):
15
+ raise NotImplementedError
16
+
17
+ def predict(self):
18
+ raise NotImplementedError
19
+
20
+ def set_mode(self, mode):
21
+ if mode == 'train':
22
+ self.model.train()
23
+ elif mode == 'eval':
24
+ self.model.eval()
25
+ else:
26
+ raise Exception('Invalid model mode {}'.format(mode))
27
+
28
+ def requires_grad_false(self):
29
+ for param in self.model.parameters():
30
+ param.requires_grad = False
31
+
32
+ def set_device(self, device):
33
+ # print(device)
34
+ if isinstance(device, list):
35
+ if len(device) > 1:
36
+ self.model = nn.DataParallel(self.model, device_ids=device)
37
+ _device = 'cuda'
38
+ else:
39
+ _device = 'cuda:{}'.format(device[0])
40
+ self.model.to(_device)
41
+ else:
42
+ self.model.to(device)
43
+
44
+ def save_model(self, path):
45
+ torch.save(self.model.state_dict(), path)
46
+
47
+ def load_model(self, path):
48
+ state_dict = torch.load(path, map_location='cpu')
49
+
50
+ remove_module = True
51
+ for k, v in state_dict.items():
52
+ if not k.startswith('module.'):
53
+ remove_module = False
54
+ break
55
+ if remove_module:
56
+ # create new OrderedDict that does not contain `module.`
57
+ new_state_dict = OrderedDict()
58
+ for k, v in state_dict.items():
59
+ name = k[7:]  # strip the 'module.' prefix
60
+ new_state_dict[name] = v
61
+
62
+ msg = self.model.load_state_dict(new_state_dict)
63
+ else:
64
+ msg = self.model.load_state_dict(state_dict)
65
+ print(msg)
66
+
67
+
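+ # load_model above transparently strips the 'module.' prefix that nn.DataParallel adds
+ # to parameter names, so a checkpoint saved from a multi-GPU run can be loaded into a
+ # single-GPU model, e.g. 'module.encoder.conv.weight' -> 'encoder.conv.weight'.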
68
+ class Processor(BasicProcessor):
69
+ def __init__(self, model, training_params, training) -> None:
70
+ self.model = model
71
+
72
+ if training:
73
+ self.opt = Optimizer([self.model], training_params)
74
+ self.crit = DiceBCE()
75
+
76
+ def fit(self, xs, ys, device, **kwargs):
77
+ self.opt.z_grad()
78
+
79
+ if len(device) > 1:
80
+ _device = 'cuda'
81
+ else:
82
+ _device = 'cuda:{}'.format(device[0])
83
+ xs = xs.type(torch.FloatTensor).to(_device)
84
+ ys = ys.type(torch.FloatTensor).to(_device)
85
+
86
+ scores = self.model(xs)
87
+ loss = self.crit(scores, ys)
88
+
89
+ loss.backward()
90
+ self.opt.g_step()
91
+ self.opt.update_lr()
92
+
93
+ return scores, loss
94
+
95
+ def predict(self, x, device, **kwargs):
96
+ if len(device) > 1:
97
+ _device = 'cuda'
98
+ else:
99
+ _device = 'cuda:{}'.format(device[0])
100
+ x = x.type(torch.FloatTensor).to(_device)
101
+ return self.model(x)
102
+
103
+
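+ # Example training / inference calls (illustrative shapes; device is a list of GPU ids):
+ #   scores, loss = processor.fit(xs, ys, device=[0])   # xs: (B, 3, H, W), ys: (B, 1, H, W)
+ #   preds = processor.predict(x, device=[0])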
104
+ class DCPProcessor(BasicProcessor):
105
+ def __init__(self, model, training_params, training=True) -> None:
106
+ self.model = model
107
+ if training:
108
+ if 'prompt_lr' in training_params:
109
+ prompt_lr = training_params['prompt_lr']
110
+ self.opt = Optimizer([self.model.encoder, self.model.decoder, self.model.Last_Conv, self.model.att1, self.model.att2, self.model.att3, self.model.att4, self.model.att5], training_params,
111
+ sep_lr=prompt_lr, sep_params=[self.model.cha_promot1, self.model.cha_promot2, self.model.cha_promot3, self.model.cha_promot4, self.model.cha_promot5, self.model.pos_promot1, self.model.pos_promot2, self.model.pos_promot3, self.model.pos_promot4, self.model.pos_promot5])
112
+ else:
113
+ self.opt = Optimizer([self.model], training_params)
114
+ self.crit = DiceBCE()
115
+
116
+ def fit(self, xs, ys, device, **kwargs):
117
+ dataset_idx = kwargs['dataset_idx']
118
+ self.opt.z_grad()
119
+ if len(device) > 1:
120
+ _device = 'cuda'
121
+ else:
122
+ _device = 'cuda:{}'.format(device[0])
123
+
124
+ xs = xs.type(torch.FloatTensor).to(_device)
125
+ ys = ys.type(torch.FloatTensor).to(_device)
126
+
127
+ scores = self.model(xs, dataset_idx)
128
+ loss = self.crit(scores, ys)
129
+
130
+ loss.backward()
131
+
132
+ self.opt.g_step()
133
+ self.opt.update_lr()
134
+
135
+ return scores, loss
136
+
137
+ def predict(self, x, device, **kwargs):
138
+ dataset_idx = kwargs['dataset_idx']
139
+ #print(dataset_idx)
140
+ if isinstance(device, list):
141
+ if len(device) > 1:
142
+ _device = 'cuda'
143
+ else:
144
+ _device = 'cuda:{}'.format(device[0])
145
+ else:
146
+ _device = device
147
+
148
+ x = x.type(torch.FloatTensor).to(_device)
149
+
150
+ return self.model(x, dataset_idx)
151
+
152
+
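+ # Example (illustrative): DCP models additionally take the per-sample dataset index
+ # that selects which learned prompts to use, e.g.
+ #   preds = dcp_processor.predict(x, device=torch.device('cpu'),
+ #                                 dataset_idx=torch.tensor([0]))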
153
+ class JTFNProcessor(BasicProcessor):
154
+ def __init__(self, model, training_params, training=True) -> None:
155
+ # model_params = training_params['model_params']
156
+ # n_class = model_params['n_class']
157
+
158
+ self.model = model
159
+ self.steps = training_params['steps']
160
+
161
+ if training:
162
+ self.opt = Optimizer([self.model], training_params)
163
+ # self.crit = DiceLoss()
164
+ self.crit = DiceBCE()
165
+
166
+ def fit(self, xs, ys, device, **kwargs):
167
+ self.opt.z_grad()
168
+
169
+ #num_domains = len(xs)
170
+ batch_size = len(xs)
171
+
172
+ if len(device) > 1:
173
+ _device = 'cuda'
174
+ else:
175
+ _device = 'cuda:{}'.format(device[0])
176
+ #xs = torch.concatenate(xs, dim=0).type(torch.FloatTensor).to(_device)
177
+ #ys = torch.concatenate(ys, dim=0).type(torch.FloatTensor).to(_device)
178
+ xs = xs.type(torch.FloatTensor).to(_device)
179
+ ys = ys.type(torch.FloatTensor).to(_device)
180
+
181
+ ys_boundary = generate_BD(ys)
182
+ _, _, h, w = ys.size()
183
+
184
+ outputs = self.model(xs)
185
+ loss = 0
186
+ for i in range(self.steps):
187
+ pred_seg = outputs['step_{}_seg'.format(i)]
188
+ pred_bou = outputs['step_{}_bou'.format(i)]
189
+
190
+ for j in range(len(pred_seg)):
191
+ p_seg = F.interpolate(pred_seg[j], (h, w), mode='bilinear', align_corners=True)
192
+ p_bou = F.interpolate(pred_bou[j], (h, w), mode='bilinear', align_corners=True)
193
+
194
+ loss += self.crit(p_seg, ys) + self.crit(p_bou, ys_boundary)
195
+ loss /= len(pred_seg)
196
+ loss.backward()
197
+ self.opt.g_step()
198
+ self.opt.update_lr()
199
+
200
+
201
+ scores = outputs['output']
202
+ # _, C, H, W = scores.size()
203
+
204
+ # scores = scores.view(num_domains, batch_size, C, H, W)
205
+ # scores = scores.cpu().numpy()
206
+ return scores, loss
207
+
208
+ def predict(self, x, device, **kwargs):
209
+ if len(device) > 1:
210
+ _device = 'cuda'
211
+ else:
212
+ _device = 'cuda:{}'.format(device[0])
213
+ x = x.type(torch.FloatTensor).to(_device)
214
+ outputs = self.model(x)
215
+
216
+ return outputs['output']
217
+
218
+
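+ # The JTFN processors supervise every refinement step: for each step, each multi-scale
+ # segmentation / boundary pair is upsampled to the label resolution and scored with
+ # DiceBCE against the mask and its boundary map (generate_BD); the summed loss is then
+ # averaged over the decoder scales before backpropagation.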
219
+ class JTFNDCPProcessor(BasicProcessor):
220
+ def __init__(self, model, training_params, training=True) -> None:
221
+ # model_params = training_params['model_params']
222
+ # n_class = model_params['n_class']
223
+
224
+ self.model = model
225
+ self.steps = training_params['steps']
226
+
227
+ if training:
228
+
229
+ self.opt = Optimizer([self.model], training_params)
230
+ # self.crit = DiceLoss()
231
+ self.crit = DiceBCE()
232
+
233
+ def fit(self, xs, ys, device, **kwargs):
234
+ dataset_idx = kwargs['dataset_idx']
235
+ self.opt.z_grad()
236
+
237
+ if len(device) > 1:
238
+ _device = 'cuda'
239
+ else:
240
+ _device = 'cuda:{}'.format(device[0])
241
+ xs = xs.type(torch.FloatTensor).to(_device)
242
+ ys = ys.type(torch.FloatTensor).to(_device)
243
+
244
+ ys_boundary = generate_BD(ys)
245
+ _, _, h, w = ys.size()
246
+
247
+ outputs = self.model(xs, dataset_idx)
248
+ loss = 0
249
+ for i in range(self.steps):
250
+ pred_seg = outputs['step_{}_seg'.format(i)]
251
+ pred_bou = outputs['step_{}_bou'.format(i)]
252
+
253
+ for j in range(len(pred_seg)):
254
+ p_seg = F.interpolate(pred_seg[j], (h, w), mode='bilinear', align_corners=True)
255
+ p_bou = F.interpolate(pred_bou[j], (h, w), mode='bilinear', align_corners=True)
256
+
257
+ loss += self.crit(p_seg, ys) + self.crit(p_bou, ys_boundary)
258
+ loss /= len(pred_seg)
259
+ loss.backward()
260
+ self.opt.g_step()
261
+ self.opt.update_lr()
262
+
263
+ scores = outputs['output']
264
+
265
+ return scores, loss
266
+
267
+ def predict(self, x, device, **kwargs):
268
+ dataset_idx = kwargs['dataset_idx']
269
+ if len(device) > 1:
270
+ _device = 'cuda'
271
+ else:
272
+ _device = 'cuda:{}'.format(device[0])
273
+ x = x.type(torch.FloatTensor).to(_device)
274
+
275
+ outputs = self.model(x, dataset_idx)
276
+
277
+ return outputs['output']
278
+
279
+
280
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ numpy==1.26.4
2
+ opencv-python==4.9.0.80
3
+ torch==2.3.0
4
+ timm==1.0.3