qijie.wei committed · Commit c5f4ee2
1 Parent(s): 55cc90a
first commit

Files changed:
- .gitignore +6 -0
- README.md +2 -1
- app.py +33 -0
- inference.py +88 -0
- models/UNet_p.py +694 -0
- models/__init__.py +1 -0
- models/backbones/__init__.py +1 -0
- models/backbones/backbones.py +25 -0
- models/crit/__init__.py +4 -0
- models/crit/dice.py +36 -0
- models/crit/focal_loss.py +23 -0
- models/crit/get_bd.py +12 -0
- models/crit/mmd.py +28 -0
- models/jtfn.py +465 -0
- models/models.py +131 -0
- models/optimizer.py +72 -0
- models/processor.py +280 -0
- requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
.DS_Store
data*
checkpoints
*.pyc
output_images
*.pkl
README.md CHANGED
@@ -10,4 +10,5 @@ pinned: false
 license: cc-by-nc-sa-4.0
 ---

-
+Demo for our ICASSP 2025 paper [Convolutional Prompting for Broad-Domain Retinal Vessel Segmentation](https://arxiv.org/abs/2412.18089).
+Please refer to [https://github.com/ruc-aimc-lab/dcp](https://github.com/ruc-aimc-lab/dcp) for more information.
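For reference, a minimal sketch of driving the segmentation outside the Gradio UI, using the Inference class added in inference.py below. The checkpoint is fetched the same way app.py does; the input image path and output filename here are hypothetical examples.

# Minimal usage sketch (assumptions: huggingface_hub is installed and the
# AIMClab-RUC/UNet_DCP_1024 snapshot contains config.json and model.pkl, as in app.py).
import cv2
from huggingface_hub import snapshot_download
from inference import Inference

model_path = snapshot_download(repo_id='AIMClab-RUC/UNet_DCP_1024')
engine = Inference(model_path=model_path)

# 'example_cfp.jpg' is a hypothetical color fundus photograph; the modality string
# must be one of the keys in the model's modality_mapping (CFP, UWF, FFA, SLO, OCTA).
vessel_map = engine.inference('example_cfp.jpg', 'CFP')  # uint8 probability map in [0, 255]
cv2.imwrite('vessels_cfp.png', vessel_map)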
app.py ADDED
@@ -0,0 +1,33 @@
import gradio as gr
from inference import Inference
import os
from huggingface_hub import snapshot_download

#MODEL_ID = os.getenv("MODEL_ID", "your_username/your_model_name")  # replace with your model ID

model_path = snapshot_download(repo_id='AIMClab-RUC/UNet_DCP_1024')

TEXT_OPTIONS = ["CFP", "UWF", "FFA", "SLO", "OCTA"]

inference_engine = Inference(model_path=model_path)

def main(image, text):
    out = inference_engine.inference(image, text)
    return out

interface = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="numpy"),
        gr.Dropdown(
            choices=TEXT_OPTIONS,
            label="Modality",
            value=TEXT_OPTIONS[0]
        )
    ],
    outputs=gr.Image(type="numpy"),
    title="Broad domain retinal vessel segmentation",
    description=""
)

interface.launch()
inference.py ADDED
@@ -0,0 +1,88 @@
# Example code for running inference on a pre-trained model
import os
import json
import numpy as np
import cv2
import torch
from models import build_model


# os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

def sigmoid(arr):
    return 1. / (1 + np.exp(-arr))

class Inference(object):
    def __init__(self, model_path):
        self.model_path = model_path
        config_path = os.path.join(model_path, 'config.json')
        with open(config_path) as fin:
            params = json.load(fin)
        self.model_params = params['model_params']
        self.modality_mapping = params['modality_mapping']
        self.model = self.load_model()


    def inference(self, image, modality):
        assert modality in self.modality_mapping, "Modality '{}' not supported".format(modality)

        image = self.load_image(image)
        modality_idx = self.modality_mapping[modality]
        modality_idx = torch.tensor([modality_idx])
        with torch.no_grad():
            output = self.model.predict(x=image, device=device, dataset_idx=modality_idx)
        output = output.data.cpu().numpy()[0][0]
        output = sigmoid(output) * 255
        output = output.astype(np.uint8)
        return output

    def load_image(self, image):
        # Load the image and preprocess it
        if isinstance(image, str):
            image = cv2.imread(image)[:, :, [2, 1, 0]]
        #image = image
        image = cv2.resize(image, (self.model_params['size_w'], self.model_params['size_h']))
        image = image.astype(np.float32) / 255.0
        image = np.transpose(image, (2, 0, 1))
        image = np.expand_dims(image, axis=0)
        image = torch.tensor(image)
        return image

    def load_model(self):
        print('Loading model from {}'.format(self.model_path))
        model = build_model(model_name=self.model_params['net'],
                            model_params=self.model_params,
                            training=False,
                            dataset_idx=list(self.modality_mapping.values()),
                            pretrained=False)
        #print(model.model.pos_promot3['0'])

        model.set_device(device)
        # model.requires_grad_false()
        model.load_model(os.path.join(self.model_path, 'model.pkl'))
        model.set_mode('eval')

        return model


if __name__ == '__main__':
    model_path = 'checkpoints/UNet_DCP_1024'
    image_paths = [
        'images/FFA.bmp',
        'images/CFP.jpg',
        'images/SLO.jpg',
        'images/UWF.jpg',
        'images/OCTA.png'
    ]
    modalities = ['FFA', 'CFP', 'SLO', 'UWF', 'OCTA']

    output_root = 'output_images'
    os.makedirs(output_root, exist_ok=True)

    inference = Inference(model_path)

    for image_path, modality in zip(image_paths, modalities):
        output = inference.inference(image_path, modality)
        cv2.imwrite(os.path.join(output_root, '{}.png'.format(modality)), output)
models/UNet_p.py ADDED
@@ -0,0 +1,694 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


def rand(size, val=0.01):
    out = torch.zeros(size)

    nn.init.uniform_(out, -val, val)
    return out

# from medsam
def window_partition(x: torch.Tensor, window_size: int):
    B, C, H, W = x.size()
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)
    windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size, window_size)
    return windows, (Hp, Wp), (Hp // window_size, Wp // window_size)

def prompt_partition(prompt: torch.Tensor, h_windows: int, w_windows: int):
    # prompt: B, C, H, W
    B, C, H, W = prompt.size()
    prompt = prompt.view(B, 1, 1, C, H, W)
    prompt = prompt.repeat((1, h_windows, w_windows, 1, 1, 1)).contiguous().view(-1, C, H, W)
    return prompt

def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]):
    # windows: B * Hp // window_size * Wp // window_size, C, window_size, window_size
    Hp, Wp = pad_hw
    H, W = hw
    B = (windows.shape[0] * window_size * window_size) // (Hp * Wp)
    # 0 1 2 3 4 5
    x = windows.view(B, Hp // window_size, Wp // window_size, -1, window_size, window_size)
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(B, -1, Hp, Wp)

    if Hp > H or Wp > W:
        x = x[:, :, :H, :W].contiguous()
    return x


class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        cdf = 0.5 * (1 + torch.erf(x / 2**0.5))
        return x * cdf


class OneLayerRes(nn.Module):
    def __init__(self, in_features, out_features, kernel_size, padding) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding)
        self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, x):
        x = x + self.weight * self.conv(x)
        return x


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.2):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8, drop_rate=0.2):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.norm = nn.LayerNorm(dim)

        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.drop = nn.Dropout(drop_rate)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, heat=False):
        B, N, C = x.shape
        out = self.norm(x)
        qkv = self.qkv(out).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.drop(out)
        out = x + out
        if heat:
            return out, attn
        return out


class MultiHeadAttention2D_POS(nn.Module):
    def __init__(self, dim_q, dim_k, dim_v, embed_dim, num_heads=8, drop_rate=0.2, embed_dim_ratio=4, stride=1, slide=0):
        super().__init__()
        self.stride = stride
        self.num_heads = num_heads

        self.slide = slide

        self.embed_dim_qk = embed_dim // embed_dim_ratio

        if self.embed_dim_qk % num_heads != 0:
            self.embed_dim_qk = (self.embed_dim_qk // num_heads + 1) * num_heads

        self.embed_dim_v = embed_dim
        if self.embed_dim_v % num_heads != 0:
            self.embed_dim_v = (self.embed_dim_v // num_heads + 1) * num_heads

        head_dim = self.embed_dim_qk // num_heads

        self.scale = head_dim ** -0.5

        self.conv_q = nn.Conv2d(in_channels=dim_q, out_channels=self.embed_dim_qk, kernel_size=stride, padding=0, stride=stride)
        self.conv_k = nn.Conv2d(in_channels=dim_k, out_channels=self.embed_dim_qk, kernel_size=stride, padding=0, stride=stride)
        self.conv_v = nn.Conv2d(in_channels=dim_v, out_channels=self.embed_dim_v, kernel_size=stride, padding=0, stride=stride)

        self.drop = nn.Dropout(drop_rate)
        self.proj_out = nn.Conv2d(in_channels=self.embed_dim_v, out_channels=dim_q, kernel_size=3, padding=1)
        if self.stride > 1:
            self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear')
        else:
            self.upsample = nn.Identity()

        self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, q, k, v, heat=False):
        B, _, H_q, W_q = q.size()
        _, _, H_kv, W_kv = k.size()

        H_q = H_q // self.stride
        W_q = W_q // self.stride
        H_kv = H_kv // self.stride
        W_kv = W_kv // self.stride

        proj_q = self.conv_q(q).reshape(B, self.num_heads, self.embed_dim_qk // self.num_heads, H_q * W_q).permute(0, 1, 3, 2).contiguous()
        proj_k = self.conv_k(k).reshape(B, self.num_heads, self.embed_dim_qk // self.num_heads, H_kv * W_kv).permute(0, 1, 3, 2).contiguous()
        proj_v = self.conv_v(v).reshape(B, self.num_heads, self.embed_dim_v // self.num_heads, H_kv * W_kv).permute(0, 1, 3, 2).contiguous()

        attn = (proj_q @ proj_k.transpose(-2, -1)).contiguous() * self.scale  # B, self.num_heads, H_q * W_q, H_kv * W_kv
        attn = attn.softmax(dim=-1)
        attn = self.drop(attn)

        out = (attn @ proj_v)  # B, self.num_heads, H_q * W_q, self.embed_dim // self.num_heads
        out = out.transpose(2, 3).contiguous().reshape(B, self.embed_dim_v, H_q, W_q)

        if self.slide > 0:
            out = out[:, :, self.slide // self.stride:]
            q = q[:, :, self.slide:]

        out = self.proj_out(out)
        out = self.upsample(out)
        out = self.drop(out)
        out = q + out * self.gamma
        return out


class MultiHeadAttention2D_CHA(nn.Module):
    def __init__(self, dim_q, dim_kv, stride, num_heads=8, drop_rate=0.2, slide=0):
        super().__init__()
        self.num_heads = num_heads
        self.stride = stride
        self.slide = slide
        self.dim_q_out = dim_q - slide

        self.conv_q = nn.Conv2d(in_channels=dim_q, out_channels=dim_q * num_heads, kernel_size=stride, stride=stride, groups=dim_q)
        self.conv_k = nn.Conv2d(in_channels=dim_kv, out_channels=dim_kv * num_heads, kernel_size=stride, stride=stride, groups=dim_kv)
        self.conv_v = nn.Conv2d(in_channels=dim_kv, out_channels=dim_kv * num_heads, kernel_size=stride, stride=stride, groups=dim_kv)


        self.drop = nn.Dropout(drop_rate)
        self.proj_out = nn.ConvTranspose2d(in_channels=self.dim_q_out * num_heads, out_channels=self.dim_q_out, kernel_size=stride, stride=stride, groups=self.dim_q_out)
        self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, q, k, v, heat=False):
        B, C_q, H_q, W_q = q.size()
        _, C_kv, H_kv, W_kv = k.size()

        proj_q = self.conv_q(q).reshape(B, self.num_heads, C_q, -1)  # batch_size * num_heads * dim_q * (H * W)
        proj_k = self.conv_k(k).reshape(B, self.num_heads, C_kv, -1)
        proj_v = self.conv_v(v).reshape(B, self.num_heads, C_kv, -1)  # batch_size * num_heads * dim_kv * (H * W)

        scale = proj_q.size(3) ** -0.5
        attn = (proj_q @ proj_k.transpose(-2, -1)).contiguous() * scale  # batch_size, num_heads, dim_q, dim_kv
        attn = attn.softmax(dim=-1)
        attn = self.drop(attn)

        out = (attn @ proj_v)  # batch_size, num_heads, dim_q, (H * W)
        if self.slide > 0:
            out = out[:, :, :-self.slide]
        out = out.reshape(B, self.num_heads * self.dim_q_out, H_q // self.stride, W_q // self.stride)

        out = self.proj_out(out)
        out = self.drop(out)
        out = q + out * self.gamma
        return out


class MultiHeadAttention2D_Dual2_2(nn.Module):
    def __init__(self, dim_pos, dim_cha, embed_dim, att_fusion, num_heads=8, drop_rate=0.2, embed_dim_ratio=4, stride=1, cha_slide=0, pos_slide=0, use_conv=True):
        super().__init__()
        self.pos_att = MultiHeadAttention2D_POS(dim_q=dim_pos, dim_k=dim_pos, dim_v=dim_pos, embed_dim=embed_dim, num_heads=num_heads, drop_rate=drop_rate, embed_dim_ratio=embed_dim_ratio, stride=stride, slide=pos_slide)
        self.cha_att = MultiHeadAttention2D_CHA(dim_q=dim_cha, dim_kv=dim_cha, num_heads=num_heads, drop_rate=drop_rate, slide=cha_slide, stride=stride)
        self.att_fusion = att_fusion  # concat, add

        if att_fusion == 'concat':
            channel_in = 2 * (dim_pos - cha_slide)
        if att_fusion == 'add':
            channel_in = (dim_pos - cha_slide)
        channel_out = dim_pos - cha_slide

        self.use_conv = use_conv
        if use_conv:
            self.conv_out = nn.Sequential(nn.Dropout2d(drop_rate, True), nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1))
        else:
            self.conv_out = nn.Identity()


    def forward(self, qkv_pos, qkv_cha, heat=False):
        if qkv_cha is None:
            qkv_cha = qkv_pos
        out_pos = self.pos_att(qkv_pos, qkv_pos, qkv_pos, heat)
        out_cha = self.cha_att(qkv_cha, qkv_cha, qkv_cha, heat)

        C = out_pos.size(1)
        H = out_cha.size(2)

        if self.att_fusion == 'concat':
            out = torch.cat([out_pos[:, :, -H:], out_cha[:, :C, :]], dim=1)
        if self.att_fusion == 'add':
            out = (out_pos[:, :, -H:] + out_cha[:, :C, :]) / 2

        out = self.conv_out(out)
        return out



class ResMLP(MLP):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.2):
        super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, act_layer=act_layer, drop=drop)
        self.norm = nn.LayerNorm(in_features)

    def forward(self, x):
        out = self.norm(x)
        out = self.fc1(out)
        out = self.act(out)
        out = self.drop(out)
        out = self.fc2(out)
        out = out + x
        return out


class MHSABlock(nn.Module):
    def __init__(self, dim, num_heads=8, drop_rate=0.2) -> None:
        super().__init__()
        self.mhsa = MultiHeadSelfAttention(dim=dim, num_heads=num_heads, drop_rate=drop_rate)
        self.mlp = ResMLP(in_features=dim, hidden_features=dim*4, out_features=dim)

    def forward(self, x, heat=False):

        if heat:
            x, attn = self.mhsa(x, heat=True)
        else:
            x = self.mhsa(x)
        x = self.mlp(x)
        if heat:
            return x, attn
        return x


class SelfAttentionBlocks(nn.Module):
    def __init__(self, dim, block_num, num_heads=8, drop_rate=0.2):
        super().__init__()
        self.block_num = block_num
        assert self.block_num >= 1

        self.blocks = nn.ModuleList([MHSABlock(dim=dim, num_heads=num_heads, drop_rate=drop_rate)
                                     for i in range(self.block_num)])

    def forward(self, x, heat=False):
        attns = []
        for blk in self.blocks:
            if heat:
                x, attn = blk(x, heat=True)
                attns.append(attn)
            else:
                x = blk(x)
        if heat:
            return x, attns
        return x


class conv_block(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(conv_block,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )


    def forward(self,x):
        x = self.conv(x)
        return x


class up_conv(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(up_conv,self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.up(x)
        return x


class Recurrent_block(nn.Module):
    def __init__(self,ch_out,t=2):
        super(Recurrent_block,self).__init__()
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        for i in range(self.t):

            if i==0:
                x1 = self.conv(x)

            x1 = self.conv(x+x1)
        return x1


class RRCNN_block(nn.Module):
    def __init__(self,ch_in,ch_out,t=2):
        super(RRCNN_block,self).__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out,t=t),
            Recurrent_block(ch_out,t=t)
        )
        self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)

    def forward(self,x):
        x = self.Conv_1x1(x)
        x1 = self.RCNN(x)
        return x+x1


class single_conv(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(single_conv,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.conv(x)
        return x


class Attention_block(nn.Module):
    def __init__(self,F_g, F_l, F_int):
        super(Attention_block,self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self,g,x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1+x1)
        psi = self.psi(psi)

        return x*psi


class R2AttUNetDecoder(nn.Module):
    def __init__(self, channels, t=2):
        super(R2AttUNetDecoder,self).__init__()

        self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear')

        self.Up5 = up_conv(ch_in=channels[4], ch_out=channels[3])
        self.Att5 = Attention_block(F_g=channels[3], F_l=channels[3], F_int=channels[3]//2)
        self.Up_RRCNN5 = RRCNN_block(ch_in=2 * channels[3], ch_out=channels[3], t=t)

        self.Up4 = up_conv(ch_in=channels[3], ch_out=channels[2])
        self.Att4 = Attention_block(F_g=channels[2], F_l=channels[2], F_int=channels[2]//2)
        self.Up_RRCNN4 = RRCNN_block(ch_in=2 * channels[2], ch_out=channels[2], t=t)

        self.Up3 = up_conv(ch_in=channels[2], ch_out=channels[1])
        self.Att3 = Attention_block(F_g=channels[1], F_l=channels[1], F_int=channels[1]//2)
        self.Up_RRCNN3 = RRCNN_block(ch_in=2 * channels[1], ch_out=channels[1], t=t)

        self.Up2 = up_conv(ch_in=channels[1], ch_out=channels[0])
        self.Att2 = Attention_block(F_g=channels[0], F_l=channels[0], F_int=channels[0]//2)
        self.Up_RRCNN2 = RRCNN_block(ch_in=2 * channels[0], ch_out=channels[0], t=t)

    def forward(self, x1, x2, x3, x4, x5):

        out = self.Up5(x5)
        x4_att = self.Att5(g=out, x=x4)
        out = torch.cat((x4_att, out),dim=1)
        out = self.Up_RRCNN5(out)

        out = self.Up4(out)
        x3_att = self.Att4(g=out, x=x3)
        out = torch.cat((x3_att, out),dim=1)
        out = self.Up_RRCNN4(out)

        out = self.Up3(out)
        x2_att = self.Att3(g=out, x=x2)
        out = torch.cat((x2_att, out),dim=1)
        out = self.Up_RRCNN3(out)

        out = self.Up2(out)
        x1_att = self.Att2(g=out, x=x1)
        out = torch.cat((x1_att, out),dim=1)
        out = self.Up_RRCNN2(out)

        out = self.Upsample(out)

        return out


class ConvBlock(nn.Module):
    def __init__(self, ch_in, ch_out, kernel_size=3, stride=1, padding=0, bias=True):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.bn2 = nn.BatchNorm2d(ch_out)
        self.activate = nn.LeakyReLU(negative_slope=0.01)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.activate(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.activate(out)
        return out


class UNetDecoder(nn.Module):
    def __init__(self, channels):
        super(UNetDecoder,self).__init__()

        self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear')

        self.Up5 = up_conv(ch_in=channels[4], ch_out=channels[3])
        self.conv5 = ConvBlock(ch_in=2 * channels[3], ch_out=channels[3], kernel_size=3, stride=1, padding=1)

        self.Up4 = up_conv(ch_in=channels[3], ch_out=channels[2])
        self.conv4 = ConvBlock(ch_in=2 * channels[2], ch_out=channels[2], kernel_size=3, stride=1, padding=1)

        self.Up3 = up_conv(ch_in=channels[2], ch_out=channels[1])
        self.conv3 = ConvBlock(ch_in=2 * channels[1], ch_out=channels[1], kernel_size=3, stride=1, padding=1)

        self.Up2 = up_conv(ch_in=channels[1], ch_out=channels[0])
        self.conv2 = ConvBlock(ch_in=2 * channels[0], ch_out=channels[0], kernel_size=3, stride=1, padding=1)

    def forward(self, x1, x2, x3, x4, x5):

        out = self.Up5(x5)
        out = torch.cat((x4, out),dim=1)
        out = self.conv5(out)

        out = self.Up4(out)
        out = torch.cat((x3, out),dim=1)
        out = self.conv4(out)

        out = self.Up3(out)
        out = torch.cat((x2, out),dim=1)
        out = self.conv3(out)

        out = self.Up2(out)
        out = torch.cat((x1, out),dim=1)
        out = self.conv2(out)

        out = self.Upsample(out)

        return out


class U_Net_P(nn.Module):
    def __init__(self, encoder, decoder, output_ch, num_classes):
        super(U_Net_P, self).__init__()

        self.encoder = encoder
        self.decoder = decoder

        self.Last_Conv = nn.Conv2d(output_ch, num_classes, kernel_size=3, stride=1, padding=1)


    def forward(self, x):
        # encoding path
        x1, x2, x3, x4, x5 = self.encoder(x)
        x = self.decoder(x1, x2, x3, x4, x5)
        x = self.Last_Conv(x)

        return x


class Prompt_U_Net_P_DCP(nn.Module):
    def __init__(self, encoder, decoder, output_ch, num_classes, dataset_idx, encoder_channels, prompt_init, pos_promot_channels, cha_promot_channels, embed_ratio, strides, local_window_sizes, att_fusion, use_conv):
        super(Prompt_U_Net_P_DCP, self).__init__()
        self.dataset_idx = dataset_idx
        self.local_window_sizes = local_window_sizes

        self.encoder = encoder
        self.decoder = decoder

        self.Last_Conv = nn.Conv2d(output_ch, num_classes, kernel_size=3, stride=1, padding=1)
        if prompt_init == 'zero':
            p_init = torch.zeros
        elif prompt_init == 'one':
            p_init = torch.ones
        elif prompt_init == 'rand':
            p_init = rand

        else:
            raise Exception(prompt_init)

        self.pos_promot_channels = pos_promot_channels
        pos_p1 = p_init((1, encoder_channels[0], pos_promot_channels[0], local_window_sizes[0]))
        pos_p2 = p_init((1, encoder_channels[1], pos_promot_channels[1], local_window_sizes[1]))
        pos_p3 = p_init((1, encoder_channels[2], pos_promot_channels[2], local_window_sizes[2]))
        pos_p4 = p_init((1, encoder_channels[3], pos_promot_channels[3], local_window_sizes[3]))
        pos_p5 = p_init((1, encoder_channels[4], pos_promot_channels[4], local_window_sizes[4]))
        self.pos_promot1 = nn.ParameterDict({str(k): nn.Parameter(pos_p1.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.pos_promot2 = nn.ParameterDict({str(k): nn.Parameter(pos_p2.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.pos_promot3 = nn.ParameterDict({str(k): nn.Parameter(pos_p3.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.pos_promot4 = nn.ParameterDict({str(k): nn.Parameter(pos_p4.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.pos_promot5 = nn.ParameterDict({str(k): nn.Parameter(pos_p5.detach().clone(), requires_grad=True) for k in dataset_idx})

        self.cha_promot_channels = cha_promot_channels
        cha_p1 = p_init((1, cha_promot_channels[0], local_window_sizes[0], local_window_sizes[0]))
        cha_p2 = p_init((1, cha_promot_channels[1], local_window_sizes[1], local_window_sizes[1]))
        cha_p3 = p_init((1, cha_promot_channels[2], local_window_sizes[2], local_window_sizes[2]))
        cha_p4 = p_init((1, cha_promot_channels[3], local_window_sizes[3], local_window_sizes[3]))
        cha_p5 = p_init((1, cha_promot_channels[4], local_window_sizes[4], local_window_sizes[4]))
        self.cha_promot1 = nn.ParameterDict({str(k): nn.Parameter(cha_p1.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.cha_promot2 = nn.ParameterDict({str(k): nn.Parameter(cha_p2.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.cha_promot3 = nn.ParameterDict({str(k): nn.Parameter(cha_p3.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.cha_promot4 = nn.ParameterDict({str(k): nn.Parameter(cha_p4.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.cha_promot5 = nn.ParameterDict({str(k): nn.Parameter(cha_p5.detach().clone(), requires_grad=True) for k in dataset_idx})

        self.strides = strides

        self.att1 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[0], dim_cha=encoder_channels[0] + cha_promot_channels[0], embed_dim=encoder_channels[0], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[0], pos_slide=0, cha_slide=0, use_conv=use_conv)
        self.att2 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[1], dim_cha=encoder_channels[1] + cha_promot_channels[1], embed_dim=encoder_channels[1], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[1], pos_slide=0, cha_slide=0, use_conv=use_conv)
        self.att3 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[2], dim_cha=encoder_channels[2] + cha_promot_channels[2], embed_dim=encoder_channels[2], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[2], pos_slide=0, cha_slide=0, use_conv=use_conv)
        self.att4 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[3], dim_cha=encoder_channels[3] + cha_promot_channels[3], embed_dim=encoder_channels[3], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[3], pos_slide=0, cha_slide=0, use_conv=use_conv)
        self.att5 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[4], dim_cha=encoder_channels[4] + cha_promot_channels[4], embed_dim=encoder_channels[4], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[4], pos_slide=0, cha_slide=0, use_conv=use_conv)

    def get_cha_prompts(self, dataset_idx, batch_size):
        if len(dataset_idx) != batch_size:
            raise Exception(dataset_idx, self.dataset_idx, batch_size)
        promots1 = torch.concatenate([self.cha_promot1[str(i)] for i in dataset_idx], dim=0)
        promots2 = torch.concatenate([self.cha_promot2[str(i)] for i in dataset_idx], dim=0)
        promots3 = torch.concatenate([self.cha_promot3[str(i)] for i in dataset_idx], dim=0)
        promots4 = torch.concatenate([self.cha_promot4[str(i)] for i in dataset_idx], dim=0)
        promots5 = torch.concatenate([self.cha_promot5[str(i)] for i in dataset_idx], dim=0)
        return promots1, promots2, promots3, promots4, promots5

    def get_pos_prompts(self, dataset_idx, batch_size):
        if len(dataset_idx) != batch_size:
            raise Exception(dataset_idx, self.dataset_idx)
        promots1 = torch.concatenate([self.pos_promot1[str(i)] for i in dataset_idx], dim=0)
        promots2 = torch.concatenate([self.pos_promot2[str(i)] for i in dataset_idx], dim=0)
        promots3 = torch.concatenate([self.pos_promot3[str(i)] for i in dataset_idx], dim=0)
        promots4 = torch.concatenate([self.pos_promot4[str(i)] for i in dataset_idx], dim=0)
        promots5 = torch.concatenate([self.pos_promot5[str(i)] for i in dataset_idx], dim=0)
        return promots1, promots2, promots3, promots4, promots5

    def forward(self, x, dataset_idx, return_features=False):

        if isinstance(dataset_idx, torch.Tensor):
            dataset_idx = list(dataset_idx.cpu().numpy())
        cha_promots1, cha_promots2, cha_promots3, cha_promots4, cha_promots5 = self.get_cha_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))
        pos_promots1, pos_promots2, pos_promots3, pos_promots4, pos_promots5 = self.get_pos_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))
        x1, x2, x3, x4, x5 = self.encoder(x)

        if return_features:
            pre_x1, pre_x2, pre_x3, pre_x4, pre_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone()

        h1, w1 = x1.size()[2:]
        h2, w2 = x2.size()[2:]
        h3, w3 = x3.size()[2:]
        h4, w4 = x4.size()[2:]
        h5, w5 = x5.size()[2:]
        x1, (Hp1, Wp1), (h_win1, w_win1) = window_partition(x1, self.local_window_sizes[0])
        x2, (Hp2, Wp2), (h_win2, w_win2) = window_partition(x2, self.local_window_sizes[1])
        x3, (Hp3, Wp3), (h_win3, w_win3) = window_partition(x3, self.local_window_sizes[2])
        x4, (Hp4, Wp4), (h_win4, w_win4) = window_partition(x4, self.local_window_sizes[3])
        x5, (Hp5, Wp5), (h_win5, w_win5) = window_partition(x5, self.local_window_sizes[4])

        cha_promots1 = prompt_partition(cha_promots1, h_win1, w_win1)
        cha_promots2 = prompt_partition(cha_promots2, h_win2, w_win2)
        cha_promots3 = prompt_partition(cha_promots3, h_win3, w_win3)
        cha_promots4 = prompt_partition(cha_promots4, h_win4, w_win4)
        cha_promots5 = prompt_partition(cha_promots5, h_win5, w_win5)

        pos_promots1 = prompt_partition(pos_promots1, h_win1, w_win1)
        pos_promots2 = prompt_partition(pos_promots2, h_win2, w_win2)
        pos_promots3 = prompt_partition(pos_promots3, h_win3, w_win3)
        pos_promots4 = prompt_partition(pos_promots4, h_win4, w_win4)
        pos_promots5 = prompt_partition(pos_promots5, h_win5, w_win5)

        cha_x1, cha_x2, cha_x3, cha_x4, cha_x5 = torch.cat([x1, cha_promots1], dim=1), torch.cat([x2, cha_promots2], dim=1), torch.cat([x3, cha_promots3], dim=1), torch.cat([x4, cha_promots4], dim=1), torch.cat([x5, cha_promots5], dim=1)
        pos_x1, pos_x2, pos_x3, pos_x4, pos_x5 = torch.cat([pos_promots1, x1], dim=2), torch.cat([pos_promots2, x2], dim=2), torch.cat([pos_promots3, x3], dim=2), torch.cat([pos_promots4, x4], dim=2), torch.cat([pos_promots5, x5], dim=2)

        x1, x2, x3, x4, x5 = self.att1(pos_x1, cha_x1), self.att2(pos_x2, cha_x2), self.att3(pos_x3, cha_x3), self.att4(pos_x4, cha_x4), self.att5(pos_x5, cha_x5)

        x1 = window_unpartition(x1, self.local_window_sizes[0], (Hp1, Wp1), (h1, w1))
        x2 = window_unpartition(x2, self.local_window_sizes[1], (Hp2, Wp2), (h2, w2))
        x3 = window_unpartition(x3, self.local_window_sizes[2], (Hp3, Wp3), (h3, w3))
        x4 = window_unpartition(x4, self.local_window_sizes[3], (Hp4, Wp4), (h4, w4))
        x5 = window_unpartition(x5, self.local_window_sizes[4], (Hp5, Wp5), (h5, w5))

        if return_features:
            pro_x1, pro_x2, pro_x3, pro_x4, pro_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone()

            return (pre_x1, pre_x2, pre_x3, pre_x4, pre_x5), (pro_x1, pro_x2, pro_x3, pro_x4, pro_x5)

        x = self.decoder(x1, x2, x3, x4, x5)
        x = self.Last_Conv(x)

        return x
models/__init__.py ADDED
@@ -0,0 +1 @@
from .models import build_model
models/backbones/__init__.py ADDED
@@ -0,0 +1 @@
from .backbones import build_backbone
models/backbones/backbones.py ADDED
@@ -0,0 +1,25 @@
import os
import torch
from timm.models import efficientnet, convnext


def build_backbone(model_name, pretrained):
    model = getattr(Backbones, model_name)(pretrained=pretrained)
    return model


class Backbones(object):
    @staticmethod
    def efficientnet_b3_p(pretrained):
        # channels: 24, 12, 40, 120, 384
        # for test, pretrained can be set to False
        model = efficientnet.efficientnet_b3_pruned(pretrained=pretrained, features_only=True)

        '''
        # pre-downloaded weights
        cp_path = os.path.join('checkpoints', 'effnetb3_pruned-59ecf72d.pth')
        state_dict = torch.load(cp_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict=state_dict, strict=False)'''
        return model
models/crit/__init__.py ADDED
@@ -0,0 +1,4 @@
from .focal_loss import BFLoss
from .mmd import MMDLinear
from .dice import DiceLoss, DiceBCE
from .get_bd import generate_BD
models/crit/dice.py ADDED
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


e = 1e-10  # smoothing constant to avoid division by zero


def dice_loss(pred, target, need_sigmoid=True):
    assert target.size() == pred.size()
    if need_sigmoid:
        pred = torch.sigmoid(pred)
    intersect = 2 * (pred * target).sum() + e
    union = (pred * pred).sum() + (target * target).sum() + e
    return 1 - intersect / union


class DiceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        return dice_loss(pred=pred, target=target)


class DiceBCE(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        return 0.5 * dice_loss(pred=pred, target=target) + \
               0.5 * F.binary_cross_entropy_with_logits(input=pred, target=target)
models/crit/focal_loss.py ADDED
@@ -0,0 +1,23 @@
from torch import nn
import torch.nn.functional as F
import torch

def binary_focal_loss(pred, target, alpha=0.5, gamma=2):
    assert pred.size() == target.size()
    pred = torch.sigmoid(pred)
    e = 1e-5
    loss = alpha * target * (1 - pred) ** gamma * (pred + e).log() + (1 - alpha) * (1 - target) * pred ** gamma * (1 - pred + e).log()
    loss = loss / (0.5 ** gamma)
    return -loss.mean()


class BFLoss(nn.Module):
    def __init__(self, alpha=0.5, gamma=2):
        super(BFLoss, self).__init__()
        # alpha: the weight of fg
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, pred, target, *args, **kwargs):
        return binary_focal_loss(pred, target, alpha=self.alpha, gamma=self.gamma)
models/crit/get_bd.py ADDED
@@ -0,0 +1,12 @@
import torch
import torch.nn.functional as F


def generate_BD(mask):
    #print(mask.size())
    # img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
    # mask = mask.float()
    mask = torch.abs(mask - F.max_pool2d(mask, 3, 1, 1))
    mask = mask.detach()

    return mask
models/crit/mmd.py ADDED
@@ -0,0 +1,28 @@
import torch
import torch.nn as nn
from math import gcd


def mmd_linear(f_of_X, f_of_Y):
    delta = f_of_X - f_of_Y
    loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))
    return loss

class MMDLinear(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, fea_source, fea_target):
        n_s, d_s = fea_source.size()
        n_t, d_t = fea_target.size()

        assert d_s == d_t

        if n_s != n_t:
            n = int(n_s * n_t / gcd(n_s, n_t))  # least common multiple

            fea_source = fea_source.repeat((int(n / n_s), 1))
            fea_target = fea_target.repeat((int(n / n_t), 1))
        return mmd_linear(fea_source, fea_target)
models/jtfn.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ICCV2021, Joint Topology-preserving and Feature-refinement Network for Curvilinear Structure Segmentation
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from .UNet_p import MultiHeadAttention2D_Dual2_2, rand, window_partition, window_unpartition, prompt_partition, OneLayerRes
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SpatialAttention(nn.Module):
|
| 11 |
+
def __init__(self):
|
| 12 |
+
super(SpatialAttention, self).__init__()
|
| 13 |
+
self.conv = nn.Sequential(
|
| 14 |
+
nn.Conv2d(2, 1, kernel_size=(3, 3), padding=(1, 1)),
|
| 15 |
+
nn.Conv2d(1, 1, kernel_size=(5, 5), padding=(2, 2)),
|
| 16 |
+
nn.Sigmoid()
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
avg_out = torch.mean(x, dim=1, keepdim=True)
|
| 21 |
+
max_out, _ = torch.max(x, dim=1, keepdim=True)
|
| 22 |
+
x = torch.cat([avg_out, max_out], dim=1)
|
| 23 |
+
x = self.conv(x)
|
| 24 |
+
return x
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ChannelAttention(nn.Module):
|
| 28 |
+
def __init__(self, channel, reduction=2):
|
| 29 |
+
super(ChannelAttention, self).__init__()
|
| 30 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
| 31 |
+
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
| 32 |
+
|
| 33 |
+
self.fc1 = nn.Conv2d(channel, channel // reduction, 1, bias=False)
|
| 34 |
+
self.fc2 = nn.Conv2d(channel // reduction, channel, 1, bias=False)
|
| 35 |
+
self.activate = nn.Sigmoid()
|
| 36 |
+
|
| 37 |
+
def forward(self, x):
|
| 38 |
+
avg_out = self.fc2(self.fc1(self.avg_pool(x)))
|
| 39 |
+
max_out = self.fc2(self.fc1(self.max_pool(x)))
|
| 40 |
+
out = avg_out + max_out
|
| 41 |
+
out = self.activate(out)
|
| 42 |
+
return out
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class GAU(nn.Module):
|
| 46 |
+
def __init__(self, in_channels, use_gau=True, reduce_dim=False, out_channels=None):
|
| 47 |
+
super(GAU, self).__init__()
|
| 48 |
+
self.use_gau = use_gau
|
| 49 |
+
self.reduce_dim = reduce_dim
|
| 50 |
+
|
| 51 |
+
if self.reduce_dim:
|
| 52 |
+
self.down_conv = nn.Sequential(
|
| 53 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
|
| 54 |
+
nn.BatchNorm2d(out_channels),
|
| 55 |
+
nn.ReLU(inplace=True)
|
| 56 |
+
)
|
| 57 |
+
in_channels = out_channels
|
| 58 |
+
|
| 59 |
+
if self.use_gau:
|
| 60 |
+
|
| 61 |
+
self.sa = SpatialAttention()
|
| 62 |
+
self.ca = ChannelAttention(in_channels)
|
| 63 |
+
|
| 64 |
+
self.reset_gate = nn.Sequential(
|
| 65 |
+
nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=2, dilation=2),
|
| 66 |
+
nn.BatchNorm2d(out_channels),
|
| 67 |
+
nn.ReLU(inplace=True),
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
def forward(self, x, y):
|
| 71 |
+
if self.reduce_dim:
|
| 72 |
+
x = self.down_conv(x)
|
| 73 |
+
|
| 74 |
+
if self.use_gau:
|
| 75 |
+
y = F.interpolate(y, x.shape[-2:], mode='bilinear', align_corners=True)
|
| 76 |
+
|
| 77 |
+
comx = x * y
|
| 78 |
+
resx = x * (1 - y) # bs, c, h, w
|
| 79 |
+
|
| 80 |
+
x_sa = self.sa(resx) # bs, 1, h, w
|
| 81 |
+
x_ca = self.ca(resx) # bs, c, 1, 1
|
| 82 |
+
|
| 83 |
+
O = self.reset_gate(comx)
|
| 84 |
+
M = x_sa * x_ca
|
| 85 |
+
|
| 86 |
+
RF = M * x + (1 - M) * O
|
| 87 |
+
else:
|
| 88 |
+
RF = x
|
| 89 |
+
return RF
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+

class FIM(nn.Module):

    def __init__(self, in_channels, out_channels, f_channels, use_topo=True, up=True, bottom=False):
        super(FIM, self).__init__()
        self.use_topo = use_topo
        self.up = up
        self.bottom = bottom

        if self.up:
            self.up_s = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
            self.up_t = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        else:
            self.up_s = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
            self.up_t = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

        self.decoder_s = nn.Sequential(
            nn.Conv2d(out_channels + f_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        if self.bottom:
            self.st = nn.Sequential(
                nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1),
                nn.BatchNorm2d(in_channels),
                nn.ReLU(inplace=True)
            )

        if self.use_topo:
            self.decoder_t = nn.Sequential(
                nn.Conv2d(out_channels + out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

            self.s_to_t = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

            self.t_to_s = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

            self.res_s = nn.Sequential(
                nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

    def forward(self, x_s, x_t, rf):
        if self.use_topo:
            if self.bottom:
                x_t = self.st(x_t)
            x_s = self.up_s(x_s)
            x_t = self.up_t(x_t)

            # padding
            diffY = rf.size()[2] - x_s.size()[2]
            diffX = rf.size()[3] - x_s.size()[3]

            x_s = F.pad(x_s, [diffX // 2, diffX - diffX // 2,
                              diffY // 2, diffY - diffY // 2])
            x_t = F.pad(x_t, [diffX // 2, diffX - diffX // 2,
                              diffY // 2, diffY - diffY // 2])

            rf_s = torch.cat((x_s, rf), dim=1)
            s = self.decoder_s(rf_s)
            s_t = self.s_to_t(s)

            t = torch.cat((x_t, s_t), dim=1)
            x_t = self.decoder_t(t)
            t_s = self.t_to_s(x_t)

            s_res = self.res_s(torch.cat((s, t_s), dim=1))

            x_s = s + s_res
        else:
            x_s = self.up_s(x_s)
            # padding
            diffY = rf.size()[2] - x_s.size()[2]
            diffX = rf.size()[3] - x_s.size()[3]

            x_s = F.pad(x_s, [diffX // 2, diffX - diffX // 2,
                              diffY // 2, diffY - diffY // 2])

            rf_s = torch.cat((x_s, rf), dim=1)
            s = self.decoder_s(rf_s)
            x_s = s
            x_t = x_s
        return x_s, x_t

class JTFNDecoder(nn.Module):
    def __init__(self, channels, use_topo) -> None:
        super().__init__()
        self.skip_blocks = []
        for i in range(5):
            self.skip_blocks.append(GAU(channels[i], use_gau=True, reduce_dim=False, out_channels=channels[i]))
        self.fims = []
        index = 3
        for i in range(4):
            if i == index:
                self.fims.append(FIM(channels[i + 1], channels[i], channels[i], use_topo=use_topo, up=True, bottom=True))
            else:
                self.fims.append(FIM(channels[i + 1], channels[i], channels[i], use_topo=use_topo, up=True, bottom=False))
        self.skip_blocks = nn.ModuleList(self.skip_blocks)
        self.fims = nn.ModuleList(self.fims)

    def forward(self, x1, x2, x3, x4, x5, y):
        x1 = self.skip_blocks[0](x1, y)
        x2 = self.skip_blocks[1](x2, y)
        x3 = self.skip_blocks[2](x3, y)
        x4 = self.skip_blocks[3](x4, y)
        x5 = self.skip_blocks[4](x5, y)

        x5_seg, x5_bou = x5, x5

        x4_seg, x4_bou = self.fims[3](x5_seg, x5_bou, x4)
        x3_seg, x3_bou = self.fims[2](x4_seg, x4_bou, x3)
        x2_seg, x2_bou = self.fims[1](x3_seg, x3_bou, x2)
        x1_seg, x1_bou = self.fims[0](x2_seg, x2_bou, x1)

        return [x1_seg, x2_seg, x3_seg, x4_seg], [x1_bou, x2_bou, x3_bou, x4_bou]

class JTFN(nn.Module):
    def __init__(self, encoder, decoder, channels, num_classes, steps) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.num_classes = num_classes
        self.steps = steps

        self.conv_seg1_head = nn.Conv2d(channels[0], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_seg2_head = nn.Conv2d(channels[1], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_seg3_head = nn.Conv2d(channels[2], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_seg4_head = nn.Conv2d(channels[3], num_classes, kernel_size=3, stride=1, padding=1, bias=False)

        self.conv_bou1_head = nn.Conv2d(channels[0], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_bou2_head = nn.Conv2d(channels[1], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_bou3_head = nn.Conv2d(channels[2], num_classes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_bou4_head = nn.Conv2d(channels[3], num_classes, kernel_size=3, stride=1, padding=1, bias=False)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        B, C, H, W = x.shape
        y = torch.zeros([B, self.num_classes, H, W], device=x.device)

        x1, x2, x3, x4, x5 = self.encoder(x)

        outputs = {}
        for i in range(self.steps):
            segs, bous = self.decoder(x1, x2, x3, x4, x5, y)
            x1_seg, x2_seg, x3_seg, x4_seg = segs
            x1_bou, x2_bou, x3_bou, x4_bou = bous

            x1_seg = self.conv_seg1_head(x1_seg)
            x2_seg = self.conv_seg2_head(x2_seg)
            x3_seg = self.conv_seg3_head(x3_seg)
            x4_seg = self.conv_seg4_head(x4_seg)

            x1_bou = self.conv_bou1_head(x1_bou)
            x2_bou = self.conv_bou2_head(x2_bou)
            x3_bou = self.conv_bou3_head(x3_bou)
            x4_bou = self.conv_bou4_head(x4_bou)

            y = x1_seg
            outputs['step_{}_seg'.format(i)] = [x1_seg, x2_seg, x3_seg, x4_seg]
            outputs['step_{}_bou'.format(i)] = [x1_bou, x2_bou, x3_bou, x4_bou]
        y = self.upsample(y)
        outputs['output'] = y
        return outputs

    def encoder_forward(self, x, dataset_idx):
        # EfficientNet-style encoder: run stem + bn1, then collect the stage outputs listed in
        # _stage_out_idx, passing dataset_idx only to the last block of a collected stage
        x = self.encoder.conv_stem(x)
        x = self.encoder.bn1(x)
        features = []
        if 0 in self.encoder._stage_out_idx:
            features.append(x)  # add stem out
        for i in range(len(self.encoder.blocks)):
            for j, l in enumerate(self.encoder.blocks[i]):
                if j == len(self.encoder.blocks[i]) - 1 and i + 1 in self.encoder._stage_out_idx:
                    x = l(x, dataset_idx)
                else:
                    x = l(x)
            if i + 1 in self.encoder._stage_out_idx:
                features.append(x)
        return features

class JTFN_DCP(JTFN):
    def __init__(self, encoder, decoder, channels, num_classes, steps, dataset_idx,
                 local_window_sizes, encoder_channels, pos_promot_channels, cha_promot_channels,
                 embed_ratio, strides, att_fusion, use_conv) -> None:
        super().__init__(encoder, decoder, channels, num_classes, steps)
        self.dataset_idx = dataset_idx
        self.local_window_sizes = local_window_sizes

        self.pos_promot_channels = pos_promot_channels
        pos_p1 = rand((1, encoder_channels[0], pos_promot_channels[0], local_window_sizes[0]), val=3. / encoder_channels[0] ** 0.5)
        pos_p2 = rand((1, encoder_channels[1], pos_promot_channels[1], local_window_sizes[1]), val=3. / encoder_channels[1] ** 0.5)
        pos_p3 = rand((1, encoder_channels[2], pos_promot_channels[2], local_window_sizes[2]), val=3. / encoder_channels[2] ** 0.5)
        pos_p4 = rand((1, encoder_channels[3], pos_promot_channels[3], local_window_sizes[3]), val=3. / encoder_channels[3] ** 0.5)
        pos_p5 = rand((1, encoder_channels[4], pos_promot_channels[4], local_window_sizes[4]), val=3. / encoder_channels[4] ** 0.5)
        self.pos_promot1 = nn.ParameterDict({str(k): nn.Parameter(pos_p1.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.pos_promot2 = nn.ParameterDict({str(k): nn.Parameter(pos_p2.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.pos_promot3 = nn.ParameterDict({str(k): nn.Parameter(pos_p3.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.pos_promot4 = nn.ParameterDict({str(k): nn.Parameter(pos_p4.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.pos_promot5 = nn.ParameterDict({str(k): nn.Parameter(pos_p5.detach().clone(), requires_grad=True) for k in dataset_idx})

        self.cha_promot_channels = cha_promot_channels
        cha_p1 = rand((1, cha_promot_channels[0], local_window_sizes[0], local_window_sizes[0]), val=3. / local_window_sizes[0])
        cha_p2 = rand((1, cha_promot_channels[1], local_window_sizes[1], local_window_sizes[1]), val=3. / local_window_sizes[1])
        cha_p3 = rand((1, cha_promot_channels[2], local_window_sizes[2], local_window_sizes[2]), val=3. / local_window_sizes[2])
        cha_p4 = rand((1, cha_promot_channels[3], local_window_sizes[3], local_window_sizes[3]), val=3. / local_window_sizes[3])
        cha_p5 = rand((1, cha_promot_channels[4], local_window_sizes[4], local_window_sizes[4]), val=3. / local_window_sizes[4])
        self.cha_promot1 = nn.ParameterDict({str(k): nn.Parameter(cha_p1.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.cha_promot2 = nn.ParameterDict({str(k): nn.Parameter(cha_p2.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.cha_promot3 = nn.ParameterDict({str(k): nn.Parameter(cha_p3.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.cha_promot4 = nn.ParameterDict({str(k): nn.Parameter(cha_p4.detach().clone(), requires_grad=True) for k in dataset_idx})
        self.cha_promot5 = nn.ParameterDict({str(k): nn.Parameter(cha_p5.detach().clone(), requires_grad=True) for k in dataset_idx})

        self.strides = strides

        self.att1 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[0], dim_cha=encoder_channels[0] + cha_promot_channels[0], embed_dim=encoder_channels[0], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[0], pos_slide=0, cha_slide=0, use_conv=use_conv)
        self.att2 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[1], dim_cha=encoder_channels[1] + cha_promot_channels[1], embed_dim=encoder_channels[1], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[1], pos_slide=0, cha_slide=0, use_conv=use_conv)
        self.att3 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[2], dim_cha=encoder_channels[2] + cha_promot_channels[2], embed_dim=encoder_channels[2], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[2], pos_slide=0, cha_slide=0, use_conv=use_conv)
        self.att4 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[3], dim_cha=encoder_channels[3] + cha_promot_channels[3], embed_dim=encoder_channels[3], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[3], pos_slide=0, cha_slide=0, use_conv=use_conv)
        self.att5 = MultiHeadAttention2D_Dual2_2(dim_pos=encoder_channels[4], dim_cha=encoder_channels[4] + cha_promot_channels[4], embed_dim=encoder_channels[4], att_fusion=att_fusion, num_heads=8, embed_dim_ratio=embed_ratio, stride=strides[4], pos_slide=0, cha_slide=0, use_conv=use_conv)

    def get_cha_prompts(self, dataset_idx, batch_size):
        if len(dataset_idx) != batch_size:
            raise Exception(dataset_idx, self.dataset_idx, batch_size)
        promots1 = torch.concatenate([self.cha_promot1[str(i)] for i in dataset_idx], dim=0)
        promots2 = torch.concatenate([self.cha_promot2[str(i)] for i in dataset_idx], dim=0)
        promots3 = torch.concatenate([self.cha_promot3[str(i)] for i in dataset_idx], dim=0)
        promots4 = torch.concatenate([self.cha_promot4[str(i)] for i in dataset_idx], dim=0)
        promots5 = torch.concatenate([self.cha_promot5[str(i)] for i in dataset_idx], dim=0)
        return promots1, promots2, promots3, promots4, promots5

    def get_pos_prompts(self, dataset_idx, batch_size):
        if len(dataset_idx) != batch_size:
            raise Exception(dataset_idx, self.dataset_idx)
        promots1 = torch.concatenate([self.pos_promot1[str(i)] for i in dataset_idx], dim=0)
        promots2 = torch.concatenate([self.pos_promot2[str(i)] for i in dataset_idx], dim=0)
        promots3 = torch.concatenate([self.pos_promot3[str(i)] for i in dataset_idx], dim=0)
        promots4 = torch.concatenate([self.pos_promot4[str(i)] for i in dataset_idx], dim=0)
        promots5 = torch.concatenate([self.pos_promot5[str(i)] for i in dataset_idx], dim=0)
        return promots1, promots2, promots3, promots4, promots5

    def forward(self, x, dataset_idx, return_features=False):
        if isinstance(dataset_idx, torch.Tensor):
            dataset_idx = list(dataset_idx.cpu().numpy())
        cha_promots1, cha_promots2, cha_promots3, cha_promots4, cha_promots5 = self.get_cha_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))
        pos_promots1, pos_promots2, pos_promots3, pos_promots4, pos_promots5 = self.get_pos_prompts(dataset_idx=dataset_idx, batch_size=x.size(0))

        B, C, H, W = x.shape
        y = torch.zeros([B, self.num_classes, H, W], device=x.device)

        x1, x2, x3, x4, x5 = self.encoder(x)

        if return_features:
            pre_x1, pre_x2, pre_x3, pre_x4, pre_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone()
        h1, w1 = x1.size()[2:]
        h2, w2 = x2.size()[2:]
        h3, w3 = x3.size()[2:]
        h4, w4 = x4.size()[2:]
        h5, w5 = x5.size()[2:]
        x1, (Hp1, Wp1), (h_win1, w_win1) = window_partition(x1, self.local_window_sizes[0])
        x2, (Hp2, Wp2), (h_win2, w_win2) = window_partition(x2, self.local_window_sizes[1])
        x3, (Hp3, Wp3), (h_win3, w_win3) = window_partition(x3, self.local_window_sizes[2])
        x4, (Hp4, Wp4), (h_win4, w_win4) = window_partition(x4, self.local_window_sizes[3])
        x5, (Hp5, Wp5), (h_win5, w_win5) = window_partition(x5, self.local_window_sizes[4])

        cha_promots1 = prompt_partition(cha_promots1, h_win1, w_win1)
        cha_promots2 = prompt_partition(cha_promots2, h_win2, w_win2)
        cha_promots3 = prompt_partition(cha_promots3, h_win3, w_win3)
        cha_promots4 = prompt_partition(cha_promots4, h_win4, w_win4)
        cha_promots5 = prompt_partition(cha_promots5, h_win5, w_win5)

        pos_promots1 = prompt_partition(pos_promots1, h_win1, w_win1)
        pos_promots2 = prompt_partition(pos_promots2, h_win2, w_win2)
        pos_promots3 = prompt_partition(pos_promots3, h_win3, w_win3)
        pos_promots4 = prompt_partition(pos_promots4, h_win4, w_win4)
        pos_promots5 = prompt_partition(pos_promots5, h_win5, w_win5)

        cha_x1, cha_x2, cha_x3, cha_x4, cha_x5 = torch.cat([x1, cha_promots1], dim=1), torch.cat([x2, cha_promots2], dim=1), torch.cat([x3, cha_promots3], dim=1), torch.cat([x4, cha_promots4], dim=1), torch.cat([x5, cha_promots5], dim=1)
        pos_x1, pos_x2, pos_x3, pos_x4, pos_x5 = torch.cat([pos_promots1, x1], dim=2), torch.cat([pos_promots2, x2], dim=2), torch.cat([pos_promots3, x3], dim=2), torch.cat([pos_promots4, x4], dim=2), torch.cat([pos_promots5, x5], dim=2)

        x1, x2, x3, x4, x5 = self.att1(pos_x1, cha_x1), self.att2(pos_x2, cha_x2), self.att3(pos_x3, cha_x3), self.att4(pos_x4, cha_x4), self.att5(pos_x5, cha_x5)

        x1 = window_unpartition(x1, self.local_window_sizes[0], (Hp1, Wp1), (h1, w1))
        x2 = window_unpartition(x2, self.local_window_sizes[1], (Hp2, Wp2), (h2, w2))
        x3 = window_unpartition(x3, self.local_window_sizes[2], (Hp3, Wp3), (h3, w3))
        x4 = window_unpartition(x4, self.local_window_sizes[3], (Hp4, Wp4), (h4, w4))
        x5 = window_unpartition(x5, self.local_window_sizes[4], (Hp5, Wp5), (h5, w5))

        if return_features:
            pro_x1, pro_x2, pro_x3, pro_x4, pro_x5 = x1.detach().clone(), x2.detach().clone(), x3.detach().clone(), x4.detach().clone(), x5.detach().clone()

            return (pre_x1, pre_x2, pre_x3, pre_x4, pre_x5), (pro_x1, pro_x2, pro_x3, pro_x4, pro_x5)

        outputs = {}
        for i in range(self.steps):
            segs, bous = self.decoder(x1, x2, x3, x4, x5, y)
            x1_seg, x2_seg, x3_seg, x4_seg = segs
            x1_bou, x2_bou, x3_bou, x4_bou = bous

            x1_seg = self.conv_seg1_head(x1_seg)
            x2_seg = self.conv_seg2_head(x2_seg)
            x3_seg = self.conv_seg3_head(x3_seg)
            x4_seg = self.conv_seg4_head(x4_seg)

            x1_bou = self.conv_bou1_head(x1_bou)
            x2_bou = self.conv_bou2_head(x2_bou)
            x3_bou = self.conv_bou3_head(x3_bou)
            x4_bou = self.conv_bou4_head(x4_bou)

            y = x1_seg
            outputs['step_{}_seg'.format(i)] = [x1_seg, x2_seg, x3_seg, x4_seg]
            outputs['step_{}_bou'.format(i)] = [x1_bou, x2_bou, x3_bou, x4_bou]
        y = self.upsample(y)
        outputs['output'] = y
        return outputs
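
A minimal shape-check sketch for the iterative JTFN decoder above. The ToyEncoder stand-in, the 256x256 input, num_classes=1 and steps=2 are illustrative assumptions; in this repo the encoder is the efficientnet_b3_p backbone built by the factory in models/models.py below.

import torch
import torch.nn as nn
from models.jtfn import JTFN, JTFNDecoder

class ToyEncoder(nn.Module):
    # stand-in for the efficientnet_b3_p backbone: five feature maps at strides 2, 4, 8, 16 and 32
    channels = (24, 12, 40, 120, 384)

    def forward(self, x):
        B, _, H, W = x.shape
        return [torch.randn(B, c, H >> (i + 1), W >> (i + 1)) for i, c in enumerate(self.channels)]

channels = (24, 12, 40, 120, 384)  # same tuple the factory in models/models.py uses
decoder = JTFNDecoder(channels=channels, use_topo=True)
net = JTFN(encoder=ToyEncoder(), decoder=decoder, channels=channels, num_classes=1, steps=2).eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 256, 256))
print(out['output'].shape)  # torch.Size([1, 1, 256, 256]): the stride-2 logits upsampled back to input size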
models/models.py
ADDED
@@ -0,0 +1,131 @@
from .processor import Processor, DCPProcessor, JTFNProcessor, JTFNDCPProcessor
from .UNet_p import U_Net_P, R2AttUNetDecoder, UNetDecoder, Prompt_U_Net_P_DCP
from .jtfn import JTFN, JTFNDecoder, JTFN_DCP
from .backbones import build_backbone


def build_model(model_name, model_params, training, dataset_idx, pretrained):
    model = getattr(Models, model_name)(model_params=model_params, training=training, dataset_idx=dataset_idx, pretrained=pretrained)
    return model


class Models(object):
    @staticmethod
    def effi_b3_p_unet(model_params, training, dataset_idx, pretrained=True):
        n_class = model_params['n_class']
        channels = (24, 12, 40, 120, 384)

        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = UNetDecoder(channels=channels)

        seg_net = U_Net_P(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class)
        model = Processor(model=seg_net, training_params=model_params, training=training)
        return model

    @staticmethod
    def effi_b3_p_r2attunet(model_params, training, dataset_idx, pretrained=True):
        n_class = model_params['n_class']
        channels = (24, 12, 40, 120, 384)

        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = R2AttUNetDecoder(channels=channels)

        seg_net = U_Net_P(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class)
        model = Processor(model=seg_net, training_params=model_params, training=training)
        return model

    @staticmethod
    def effi_b3_p_jtfn(model_params, training, dataset_idx, pretrained=True):
        n_class = model_params['n_class']
        channels = (24, 12, 40, 120, 384)
        steps = model_params['steps']

        encoder = build_backbone('efficientnet_b3_p')
        decoder = JTFNDecoder(channels=channels, use_topo=True)

        seg_net = JTFN(encoder=encoder, decoder=decoder, channels=channels, num_classes=n_class, steps=steps)
        model = JTFNProcessor(model=seg_net, training_params=model_params, training=training)
        return model

    @staticmethod
    def prompt_effi_b3_p_unet_dcp(model_params, training, dataset_idx, pretrained=True):
        n_class = model_params['n_class']
        channels = [24, 12, 40, 120, 384]

        cha_promot_channels = model_params['cha_promot_channels']
        pos_promot_channels = model_params['pos_promot_channels']
        local_window_sizes = model_params['local_window_sizes']
        att_fusion = model_params['att_fusion']
        prompt_init = model_params.get('prompt_init', 'rand')  # rand, zero, one
        embed_ratio = model_params['embed_ratio']
        strides = model_params['strides']
        use_conv = model_params['use_conv']

        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = UNetDecoder(channels=channels)

        seg_net = Prompt_U_Net_P_DCP(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class,
                                     dataset_idx=dataset_idx, encoder_channels=channels, prompt_init=prompt_init,
                                     cha_promot_channels=cha_promot_channels, pos_promot_channels=pos_promot_channels,
                                     embed_ratio=embed_ratio, strides=strides, local_window_sizes=local_window_sizes,
                                     att_fusion=att_fusion, use_conv=use_conv)

        model = DCPProcessor(model=seg_net, training_params=model_params, training=training)
        return model

    @staticmethod
    def prompt_effi_b3_p_r2attunet_dcp(model_params, training, dataset_idx, pretrained=True):
        n_class = model_params['n_class']
        channels = [24, 12, 40, 120, 384]

        cha_promot_channels = model_params['cha_promot_channels']
        pos_promot_channels = model_params['pos_promot_channels']
        local_window_sizes = model_params['local_window_sizes']
        att_fusion = model_params['att_fusion']
        prompt_init = model_params.get('prompt_init', 'rand')  # rand, zero, one
        embed_ratio = model_params['embed_ratio']
        strides = model_params['strides']
        use_conv = model_params['use_conv']

        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = R2AttUNetDecoder(channels=channels)

        seg_net = Prompt_U_Net_P_DCP(encoder=encoder, decoder=decoder, output_ch=channels[0], num_classes=n_class,
                                     dataset_idx=dataset_idx, encoder_channels=channels, prompt_init=prompt_init,
                                     cha_promot_channels=cha_promot_channels, pos_promot_channels=pos_promot_channels,
                                     embed_ratio=embed_ratio, strides=strides, local_window_sizes=local_window_sizes,
                                     att_fusion=att_fusion, use_conv=use_conv)

        model = DCPProcessor(model=seg_net, training_params=model_params, training=training)
        return model

    @staticmethod
    def prompt_effi_b3_p_jtfn_dcp(model_params, training, dataset_idx, pretrained=True):
        n_class = model_params['n_class']
        steps = model_params['steps']
        channels = [24, 12, 40, 120, 384]

        cha_promot_channels = model_params['cha_promot_channels']
        pos_promot_channels = model_params['pos_promot_channels']
        local_window_sizes = model_params['local_window_sizes']
        att_fusion = model_params['att_fusion']
        embed_ratio = model_params['embed_ratio']
        strides = model_params['strides']
        use_conv = model_params['use_conv']

        encoder = build_backbone('efficientnet_b3_p', pretrained=pretrained)
        decoder = JTFNDecoder(channels=channels, use_topo=True)
        seg_net = JTFN_DCP(encoder=encoder, decoder=decoder, channels=channels, num_classes=n_class, steps=steps,
                           dataset_idx=dataset_idx, local_window_sizes=local_window_sizes,
                           encoder_channels=channels,
                           cha_promot_channels=cha_promot_channels, pos_promot_channels=pos_promot_channels,
                           embed_ratio=embed_ratio, strides=strides,
                           att_fusion=att_fusion, use_conv=use_conv)

        model = JTFNDCPProcessor(model=seg_net, training_params=model_params, training=training)
        return model
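
A minimal usage sketch for the factory above, assuming inference only (with training=False the Processor does not read the optimizer settings, so only n_class is required) and that the efficientnet_b3_p backbone in models/backbones resolves locally; pretrained=False avoids a weight download. The DCP variants additionally expect the prompt-related keys read above (cha_promot_channels, pos_promot_channels, local_window_sizes, att_fusion, embed_ratio, strides, use_conv) plus a dataset_idx list; see the training configs in the GitHub repo for the actual values.

from models import build_model

processor = build_model('effi_b3_p_unet', model_params={'n_class': 1},
                        training=False, dataset_idx=None, pretrained=False)
processor.set_mode('eval')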
models/optimizer.py
ADDED
@@ -0,0 +1,72 @@
import torch.optim as optim
import torch.nn as nn
import torch
import itertools


def add_full_model_gradient_clipping(optim, clip_norm_val):

    class FullModelGradientClippingOptimizer(optim):
        def step(self, closure=None):
            all_params = itertools.chain(*[x["params"] for x in self.param_groups])
            torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
            super().step(closure=closure)

    return FullModelGradientClippingOptimizer


class Optimizer(object):
    def __init__(self, models, training_params, sep_lr=None, sep_params=None, gradient_clip=0):

        params = []
        for model in models:
            if isinstance(model, nn.Parameter):
                params += [model]
            else:
                params += list(model.parameters())
        if sep_lr is not None:
            print(sep_lr)
            add_params = []
            for model in sep_params:
                if isinstance(model, nn.Parameter):
                    add_params += [model]
                else:
                    add_params += list(model.parameters())
            params = [{'params': params},
                      {'params': add_params, 'lr': sep_lr}]

        self.lr = training_params['lr']
        self.weight_decay = training_params['weight_decay']
        method = training_params['optimizer']

        if method == 'SGD':
            self.momentum = training_params['momentum']
            if gradient_clip > 0:
                self.optim = add_full_model_gradient_clipping(optim.SGD, gradient_clip)(params, lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
            else:
                self.optim = optim.SGD(params, lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
        elif method == 'AdamW':
            self.optim = optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay)
        else:
            raise Exception('{} is not supported'.format(method))

        schedule_name = training_params['lr_schedule']
        schedule_params = training_params['schedule_params']
        if schedule_name == 'CosineAnnealingLR':
            schedule_params['T_max'] = training_params['inter_val'] * 4
        self.lr_schedule = getattr(optim.lr_scheduler, schedule_name)(self.optim, **schedule_params)

    def update_lr(self):
        self.lr_schedule.step()

    def z_grad(self):
        self.optim.zero_grad()

    def g_step(self):
        self.optim.step()

    def get_lr(self):
        for param_group in self.optim.param_groups:
            return param_group['lr']
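
A minimal sketch of driving the Optimizer wrapper above with a toy module. The hyper-parameter values are illustrative assumptions, not the repo's training settings; the key names ('lr', 'weight_decay', 'optimizer', 'lr_schedule', 'schedule_params', 'inter_val') are the ones read in __init__.

import torch
import torch.nn as nn
from models.optimizer import Optimizer

net = nn.Conv2d(3, 1, kernel_size=3, padding=1)
training_params = {
    'lr': 1e-3, 'weight_decay': 1e-4, 'optimizer': 'AdamW',   # illustrative values
    'lr_schedule': 'CosineAnnealingLR', 'schedule_params': {}, 'inter_val': 100,
}
opt = Optimizer([net], training_params)

opt.z_grad()
net(torch.randn(2, 3, 16, 16)).mean().backward()
opt.g_step()
opt.update_lr()
print(opt.get_lr())  # learning rate after one cosine-annealing step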
models/processor.py
ADDED
@@ -0,0 +1,280 @@
import torch
import torch.nn as nn
from .optimizer import Optimizer
from .crit import DiceBCE, generate_BD
from collections import OrderedDict

import torch.nn.functional as F


class BasicProcessor(object):
    def __init__(self) -> None:
        pass

    def fit(self):
        raise NotImplementedError

    def predict(self):
        raise NotImplementedError

    def set_mode(self, mode):
        if mode == 'train':
            self.model.train()
        elif mode == 'eval':
            self.model.eval()
        else:
            raise Exception('Invalid model mode {}'.format(mode))

    def requires_grad_false(self):
        for param in self.model.parameters():
            param.requires_grad = False

    def set_device(self, device):
        if isinstance(device, list):
            if len(device) > 1:
                self.model = nn.DataParallel(self.model, device_ids=device)
                _device = 'cuda'
            else:
                _device = 'cuda:{}'.format(device[0])
            self.model.to(_device)
        else:
            self.model.to(device)

    def save_model(self, path):
        torch.save(self.model.state_dict(), path)

    def load_model(self, path):
        state_dict = torch.load(path, map_location='cpu')

        remove_module = True
        for k, v in state_dict.items():
            if not k.startswith('module.'):
                remove_module = False
                break
        if remove_module:
            # create a new OrderedDict that does not contain the `module.` prefix
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove the 'module.' prefix
                new_state_dict[name] = v

            msg = self.model.load_state_dict(new_state_dict)
        else:
            msg = self.model.load_state_dict(state_dict)
        print(msg)


class Processor(BasicProcessor):
    def __init__(self, model, training_params, training) -> None:
        self.model = model

        if training:
            self.opt = Optimizer([self.model], training_params)
            self.crit = DiceBCE()

    def fit(self, xs, ys, device, **kwargs):
        self.opt.z_grad()

        if len(device) > 1:
            _device = 'cuda'
        else:
            _device = 'cuda:{}'.format(device[0])
        xs = xs.type(torch.FloatTensor).to(_device)
        ys = ys.type(torch.FloatTensor).to(_device)

        scores = self.model(xs)
        loss = self.crit(scores, ys)

        loss.backward()
        self.opt.g_step()
        self.opt.update_lr()

        return scores, loss

    def predict(self, x, device, **kwargs):
        if len(device) > 1:
            _device = 'cuda'
        else:
            _device = 'cuda:{}'.format(device[0])
        x = x.type(torch.FloatTensor).to(_device)
        return self.model(x)


class DCPProcessor(BasicProcessor):
    def __init__(self, model, training_params, training=True) -> None:
        self.model = model
        if training:
            if 'prompt_lr' in training_params:
                prompt_lr = training_params['prompt_lr']
                self.opt = Optimizer([self.model.encoder, self.model.decoder, self.model.Last_Conv, self.model.att1, self.model.att2, self.model.att3, self.model.att4, self.model.att5], training_params,
                                     sep_lr=prompt_lr, sep_params=[self.model.cha_promot1, self.model.cha_promot2, self.model.cha_promot3, self.model.cha_promot4, self.model.cha_promot5, self.model.pos_promot1, self.model.pos_promot2, self.model.pos_promot3, self.model.pos_promot4, self.model.pos_promot5])
            else:
                self.opt = Optimizer([self.model], training_params)
            self.crit = DiceBCE()

    def fit(self, xs, ys, device, **kwargs):
        dataset_idx = kwargs['dataset_idx']
        self.opt.z_grad()
        if len(device) > 1:
            _device = 'cuda'
        else:
            _device = 'cuda:{}'.format(device[0])

        xs = xs.type(torch.FloatTensor).to(_device)
        ys = ys.type(torch.FloatTensor).to(_device)

        scores = self.model(xs, dataset_idx)
        loss = self.crit(scores, ys)

        loss.backward()

        self.opt.g_step()
        self.opt.update_lr()

        return scores, loss

    def predict(self, x, device, **kwargs):
        dataset_idx = kwargs['dataset_idx']
        if isinstance(device, list):
            if len(device) > 1:
                _device = 'cuda'
            else:
                _device = 'cuda:{}'.format(device[0])
        else:
            _device = device

        x = x.type(torch.FloatTensor).to(_device)

        return self.model(x, dataset_idx)


class JTFNProcessor(BasicProcessor):
    def __init__(self, model, training_params, training=True) -> None:
        self.model = model
        self.steps = training_params['steps']

        if training:
            self.opt = Optimizer([self.model], training_params)
            self.crit = DiceBCE()

    def fit(self, xs, ys, device, **kwargs):
        self.opt.z_grad()

        batch_size = len(xs)

        if len(device) > 1:
            _device = 'cuda'
        else:
            _device = 'cuda:{}'.format(device[0])
        xs = xs.type(torch.FloatTensor).to(_device)
        ys = ys.type(torch.FloatTensor).to(_device)

        ys_boundary = generate_BD(ys)
        _, _, h, w = ys.size()

        outputs = self.model(xs)
        loss = 0
        for i in range(self.steps):
            pred_seg = outputs['step_{}_seg'.format(i)]
            pred_bou = outputs['step_{}_bou'.format(i)]

            for j in range(len(pred_seg)):
                p_seg = F.interpolate(pred_seg[j], (h, w), mode='bilinear', align_corners=True)
                p_bou = F.interpolate(pred_bou[j], (h, w), mode='bilinear', align_corners=True)

                loss += self.crit(p_seg, ys) + self.crit(p_bou, ys_boundary)
        loss /= len(pred_seg)
        loss.backward()
        self.opt.g_step()
        self.opt.update_lr()

        scores = outputs['output']
        return scores, loss

    def predict(self, x, device, **kwargs):
        if len(device) > 1:
            _device = 'cuda'
        else:
            _device = 'cuda:{}'.format(device[0])
        x = x.type(torch.FloatTensor).to(_device)
        outputs = self.model(x)

        return outputs['output']


class JTFNDCPProcessor(BasicProcessor):
    def __init__(self, model, training_params, training=True) -> None:
        self.model = model
        self.steps = training_params['steps']

        if training:
            self.opt = Optimizer([self.model], training_params)
            self.crit = DiceBCE()

    def fit(self, xs, ys, device, **kwargs):
        dataset_idx = kwargs['dataset_idx']
        self.opt.z_grad()

        if len(device) > 1:
            _device = 'cuda'
        else:
            _device = 'cuda:{}'.format(device[0])
        xs = xs.type(torch.FloatTensor).to(_device)
        ys = ys.type(torch.FloatTensor).to(_device)

        ys_boundary = generate_BD(ys)
        _, _, h, w = ys.size()

        outputs = self.model(xs, dataset_idx)
        loss = 0
        for i in range(self.steps):
            pred_seg = outputs['step_{}_seg'.format(i)]
            pred_bou = outputs['step_{}_bou'.format(i)]

            for j in range(len(pred_seg)):
                p_seg = F.interpolate(pred_seg[j], (h, w), mode='bilinear', align_corners=True)
                p_bou = F.interpolate(pred_bou[j], (h, w), mode='bilinear', align_corners=True)

                loss += self.crit(p_seg, ys) + self.crit(p_bou, ys_boundary)
        loss /= len(pred_seg)
        loss.backward()
        self.opt.g_step()
        self.opt.update_lr()

        scores = outputs['output']

        return scores, loss

    def predict(self, x, device, **kwargs):
        dataset_idx = kwargs['dataset_idx']
        if len(device) > 1:
            _device = 'cuda'
        else:
            _device = 'cuda:{}'.format(device[0])
        x = x.type(torch.FloatTensor).to(_device)

        outputs = self.model(x, dataset_idx)

        return outputs['output']
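
The 'module.' handling in BasicProcessor.load_model above lets checkpoints saved from an nn.DataParallel-wrapped model load into a plain single-device model. A small self-contained sketch; the tiny Conv2d and the file name are illustrative, not part of the repo.

import torch
import torch.nn as nn
from models.processor import Processor

tiny = nn.Conv2d(3, 1, kernel_size=3, padding=1)
# simulate a checkpoint written from nn.DataParallel: every key gains a 'module.' prefix
torch.save({'module.' + k: v for k, v in tiny.state_dict().items()}, 'tiny_dp.pkl')

proc = Processor(model=nn.Conv2d(3, 1, kernel_size=3, padding=1), training_params={}, training=False)
proc.load_model('tiny_dp.pkl')  # prefix is detected and stripped before load_state_dict
proc.set_mode('eval')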
requirements.txt
ADDED
@@ -0,0 +1,4 @@
numpy==1.26.4
opencv-python==4.9.0.80
torch==2.3.0
timm==1.0.3