UNet_DCP_1024 / models /processor.py
qijie.wei
first commit
c5f4ee2
import torch
import torch.nn as nn
from .optimizer import Optimizer
from .crit import DiceBCE, generate_BD
from collections import OrderedDict
import torch.nn.functional as F
class BasicProcessor(object):
    """Base class holding an ``nn.Module`` in ``self.model`` plus common helpers.

    Subclasses are expected to assign ``self.model`` and implement ``fit`` and
    ``predict``.
    """

    def __init__(self) -> None:
        pass

    def fit(self):
        raise NotImplementedError

    def predict(self):
        raise NotImplementedError

    def set_mode(self, mode):
        """Put ``self.model`` into training (``'train'``) or evaluation (``'eval'``) mode.

        Raises:
            ValueError: if ``mode`` is neither ``'train'`` nor ``'eval'``.
        """
        if mode == 'train':
            self.model.train()
        elif mode == 'eval':
            self.model.eval()
        else:
            # ValueError subclasses Exception, so existing ``except Exception`` callers still work.
            raise ValueError('Invalid model mode {}'.format(mode))

    def requires_grad_false(self):
        """Freeze every parameter of ``self.model`` (set ``requires_grad = False``)."""
        self.model.requires_grad_(False)

    def set_device(self, device):
        """Move ``self.model`` to ``device``.

        Args:
            device: a list of CUDA device indices, or any value accepted by
                ``nn.Module.to`` (e.g. ``'cpu'``, ``'cuda:0'``). A list with more
                than one index wraps the model in ``nn.DataParallel`` over those
                devices and moves it to ``'cuda'``; a single-element list ``[i]``
                moves it to ``'cuda:i'``.
        """
        if isinstance(device, list):
            if len(device) > 1:
                self.model = nn.DataParallel(self.model, device_ids=device)
                target = 'cuda'
            else:
                target = 'cuda:{}'.format(device[0])
        else:
            target = device
        self.model.to(target)

    def save_model(self, path):
        """Save ``self.model.state_dict()`` to ``path``.

        When the model is wrapped in ``nn.DataParallel`` the saved keys carry a
        ``'module.'`` prefix; ``load_model`` handles both forms.
        """
        torch.save(self.model.state_dict(), path)

    def load_model(self, path):
        """Load a state dict from ``path`` into ``self.model`` and print the result.

        A ``'module.'`` prefix present on *every* key (checkpoint saved from a
        ``DataParallel`` model) is stripped. If ``self.model`` itself is a
        ``DataParallel``/``DistributedDataParallel`` wrapper, the weights are
        loaded into the wrapped module, so any combination of wrapped/unwrapped
        checkpoint and model lines up.
        """
        state_dict = torch.load(path, map_location='cpu')
        prefix = 'module.'
        # Strip only when every key is prefixed (an empty dict counts as prefixed,
        # matching the previous behaviour).
        if all(k.startswith(prefix) for k in state_dict):
            state_dict = OrderedDict(
                (k[len(prefix):], v) for k, v in state_dict.items()
            )
        # A wrapper expects 'module.'-prefixed keys; loading into the inner module
        # lets the (now un-prefixed) keys match regardless of how set_device ran.
        wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
        target = self.model.module if isinstance(self.model, wrappers) else self.model
        msg = target.load_state_dict(state_dict)
        print(msg)
class Processor(BasicProcessor):
    """Processor for a single-input model: ``scores = model(xs)``, trained with ``DiceBCE``."""

    def __init__(self, model, training_params, training=True) -> None:
        """Wrap ``model``; when ``training`` is true also build the optimizer and criterion.

        Args:
            model: the network to train / run.
            training_params: passed straight to ``Optimizer`` (only read when training).
            training: build ``self.opt`` and ``self.crit`` when true.
        """
        super().__init__()
        self.model = model
        if training:
            self.opt = Optimizer([self.model], training_params)
            self.crit = DiceBCE()

    @staticmethod
    def _resolve_device(device):
        """Map ``device`` to a ``Tensor.to`` target.

        A list of CUDA indices maps to ``'cuda'`` (several indices, i.e. the
        DataParallel case) or ``'cuda:i'`` (single index ``[i]``); any other
        value (e.g. ``'cpu'``, ``torch.device``) is returned unchanged, matching
        ``BasicProcessor.set_device``.
        """
        if isinstance(device, list):
            return 'cuda' if len(device) > 1 else 'cuda:{}'.format(device[0])
        return device

    def fit(self, xs, ys, device, **kwargs):
        """Run one training step on a batch.

        Args:
            xs: input batch tensor; cast to float32 on the resolved device.
            ys: target tensor; cast to float32 on the resolved device.
            device: list of CUDA indices or a ``Tensor.to`` device (see ``_resolve_device``).
            **kwargs: unused.

        Returns:
            ``(scores, loss)``: the raw model output and the criterion's loss tensor.
        """
        # z_grad / g_step / update_lr come from the project Optimizer; presumably
        # zero grads, step parameters and advance the LR schedule — confirm there.
        self.opt.z_grad()
        target = self._resolve_device(device)
        # Single cast+move avoids the CPU round-trip of ``.type(torch.FloatTensor)``.
        xs = xs.to(device=target, dtype=torch.float32)
        ys = ys.to(device=target, dtype=torch.float32)
        scores = self.model(xs)
        loss = self.crit(scores, ys)
        loss.backward()
        self.opt.g_step()
        self.opt.update_lr()
        return scores, loss

    def predict(self, x, device, **kwargs):
        """Return ``self.model(x)`` after casting ``x`` to float32 on the resolved device."""
        x = x.to(device=self._resolve_device(device), dtype=torch.float32)
        return self.model(x)
class DCPProcessor(BasicProcessor):
    """Processor for prompt-conditioned models called as ``model(x, dataset_idx)``."""

    def __init__(self, model, training_params, training=True) -> None:
        """Wrap ``model``; when ``training`` is true also build the optimizer and criterion.

        If ``training_params`` contains ``'prompt_lr'``, the backbone modules
        (encoder, decoder, Last_Conv, att1-att5) are passed to ``Optimizer`` as
        the main parameter groups and the prompt modules (cha_promot1-5,
        pos_promot1-5) as ``sep_params`` with ``sep_lr=prompt_lr`` — presumably
        giving the prompts their own learning rate; confirm in ``Optimizer``.
        Otherwise the whole model is optimized as one group.
        """
        super().__init__()
        self.model = model
        if training:
            if 'prompt_lr' in training_params:
                prompt_lr = training_params['prompt_lr']
                backbone = [self.model.encoder, self.model.decoder, self.model.Last_Conv]
                backbone += [getattr(self.model, 'att{}'.format(i)) for i in range(1, 6)]
                # Attribute names (incl. the 'promot' spelling) must match the model.
                prompts = [getattr(self.model, 'cha_promot{}'.format(i)) for i in range(1, 6)]
                prompts += [getattr(self.model, 'pos_promot{}'.format(i)) for i in range(1, 6)]
                self.opt = Optimizer(backbone, training_params,
                                     sep_lr=prompt_lr, sep_params=prompts)
            else:
                self.opt = Optimizer([self.model], training_params)
            self.crit = DiceBCE()

    @staticmethod
    def _resolve_device(device):
        """Map ``device`` to a ``Tensor.to`` target.

        A list of CUDA indices maps to ``'cuda'`` (several indices) or
        ``'cuda:i'`` (single index ``[i]``); any other value is returned as-is.
        """
        if isinstance(device, list):
            return 'cuda' if len(device) > 1 else 'cuda:{}'.format(device[0])
        return device

    def fit(self, xs, ys, device, **kwargs):
        """Run one training step on a batch.

        Args:
            xs: input batch tensor; cast to float32 on the resolved device.
            ys: target tensor; cast to float32 on the resolved device.
            device: list of CUDA indices or a ``Tensor.to`` device (see ``_resolve_device``).
            **kwargs: must contain ``dataset_idx``, forwarded as the model's second
                argument (presumably selects the dataset-specific prompts — confirm).

        Returns:
            ``(scores, loss)``: the raw model output and the criterion's loss tensor.

        Raises:
            KeyError: if ``dataset_idx`` is missing from ``kwargs``.
        """
        dataset_idx = kwargs['dataset_idx']
        self.opt.z_grad()
        target = self._resolve_device(device)
        xs = xs.to(device=target, dtype=torch.float32)
        ys = ys.to(device=target, dtype=torch.float32)
        scores = self.model(xs, dataset_idx)
        loss = self.crit(scores, ys)
        loss.backward()
        self.opt.g_step()
        self.opt.update_lr()
        return scores, loss

    def predict(self, x, device, **kwargs):
        """Return ``self.model(x, kwargs['dataset_idx'])`` with ``x`` cast to float32 on the resolved device."""
        dataset_idx = kwargs['dataset_idx']
        x = x.to(device=self._resolve_device(device), dtype=torch.float32)
        return self.model(x, dataset_idx)
class JTFNProcessor(BasicProcessor):
    """Processor for models with per-step segmentation and boundary outputs.

    ``self.model(xs)`` must return a dict with keys ``'step_{i}_seg'`` and
    ``'step_{i}_bou'`` for every ``i`` in ``range(steps)`` (each an indexable
    sequence of prediction maps) plus the final prediction under ``'output'``.
    """

    def __init__(self, model, training_params, training=True) -> None:
        """Wrap ``model``; ``training_params['steps']`` is always read.

        Args:
            model: the network to train / run.
            training_params: must contain ``'steps'``; also passed to ``Optimizer``
                when training.
            training: build ``self.opt`` and ``self.crit`` when true.
        """
        super().__init__()
        self.model = model
        # Number of step outputs ('step_0_*' .. 'step_{steps-1}_*') used in the loss.
        self.steps = training_params['steps']
        if training:
            self.opt = Optimizer([self.model], training_params)
            self.crit = DiceBCE()

    @staticmethod
    def _resolve_device(device):
        """Map ``device`` to a ``Tensor.to`` target.

        A list of CUDA indices maps to ``'cuda'`` (several indices) or
        ``'cuda:i'`` (single index ``[i]``); any other value is returned as-is.
        """
        if isinstance(device, list):
            return 'cuda' if len(device) > 1 else 'cuda:{}'.format(device[0])
        return device

    def _step_loss(self, outputs, ys, ys_boundary):
        """Sum the criterion over every step's seg/boundary predictions.

        Each prediction is bilinearly resized to the target's spatial size
        before comparison.

        Raises:
            ValueError: if ``self.steps < 1`` (no step outputs to supervise).
        """
        if self.steps < 1:
            raise ValueError('steps must be >= 1 to compute the loss, got {}'.format(self.steps))
        _, _, h, w = ys.size()  # targets are expected to be 4-D
        loss = 0
        for i in range(self.steps):
            pred_seg = outputs['step_{}_seg'.format(i)]
            pred_bou = outputs['step_{}_bou'.format(i)]
            for j, seg in enumerate(pred_seg):
                p_seg = F.interpolate(seg, (h, w), mode='bilinear', align_corners=True)
                p_bou = F.interpolate(pred_bou[j], (h, w), mode='bilinear', align_corners=True)
                loss = loss + (self.crit(p_seg, ys) + self.crit(p_bou, ys_boundary))
        # NOTE(review): the sum over all steps is divided only by the length of the
        # *last* step's list (not by steps * len); preserved as-is — confirm intent.
        return loss / len(pred_seg)

    def fit(self, xs, ys, device, **kwargs):
        """Run one training step with deep supervision over all step outputs.

        Args:
            xs: input batch tensor; cast to float32 on the resolved device.
            ys: segmentation targets; cast to float32 on the resolved device. Boundary
                targets are derived from them with ``generate_BD``.
            device: list of CUDA indices or a ``Tensor.to`` device (see ``_resolve_device``).
            **kwargs: unused.

        Returns:
            ``(scores, loss)``: ``outputs['output']`` and the aggregated loss tensor.
        """
        self.opt.z_grad()
        target = self._resolve_device(device)
        xs = xs.to(device=target, dtype=torch.float32)
        ys = ys.to(device=target, dtype=torch.float32)
        ys_boundary = generate_BD(ys)
        outputs = self.model(xs)
        loss = self._step_loss(outputs, ys, ys_boundary)
        loss.backward()
        self.opt.g_step()
        self.opt.update_lr()
        return outputs['output'], loss

    def predict(self, x, device, **kwargs):
        """Return ``self.model(x)['output']`` with ``x`` cast to float32 on the resolved device."""
        x = x.to(device=self._resolve_device(device), dtype=torch.float32)
        return self.model(x)['output']
class JTFNDCPProcessor(BasicProcessor):
    """Processor for prompt-conditioned models with per-step seg/boundary outputs.

    ``self.model(xs, dataset_idx)`` must return a dict with keys
    ``'step_{i}_seg'`` and ``'step_{i}_bou'`` for every ``i`` in
    ``range(steps)`` (each an indexable sequence of prediction maps) plus the
    final prediction under ``'output'``.
    """

    def __init__(self, model, training_params, training=True) -> None:
        """Wrap ``model``; ``training_params['steps']`` is always read.

        Args:
            model: the network to train / run.
            training_params: must contain ``'steps'``; also passed to ``Optimizer``
                when training.
            training: build ``self.opt`` and ``self.crit`` when true.
        """
        super().__init__()
        self.model = model
        # Number of step outputs ('step_0_*' .. 'step_{steps-1}_*') used in the loss.
        self.steps = training_params['steps']
        if training:
            self.opt = Optimizer([self.model], training_params)
            self.crit = DiceBCE()

    @staticmethod
    def _resolve_device(device):
        """Map ``device`` to a ``Tensor.to`` target.

        A list of CUDA indices maps to ``'cuda'`` (several indices) or
        ``'cuda:i'`` (single index ``[i]``); any other value is returned as-is.
        """
        if isinstance(device, list):
            return 'cuda' if len(device) > 1 else 'cuda:{}'.format(device[0])
        return device

    def _step_loss(self, outputs, ys, ys_boundary):
        """Sum the criterion over every step's seg/boundary predictions.

        Each prediction is bilinearly resized to the target's spatial size
        before comparison.

        Raises:
            ValueError: if ``self.steps < 1`` (no step outputs to supervise).
        """
        if self.steps < 1:
            raise ValueError('steps must be >= 1 to compute the loss, got {}'.format(self.steps))
        _, _, h, w = ys.size()  # targets are expected to be 4-D
        loss = 0
        for i in range(self.steps):
            pred_seg = outputs['step_{}_seg'.format(i)]
            pred_bou = outputs['step_{}_bou'.format(i)]
            for j, seg in enumerate(pred_seg):
                p_seg = F.interpolate(seg, (h, w), mode='bilinear', align_corners=True)
                p_bou = F.interpolate(pred_bou[j], (h, w), mode='bilinear', align_corners=True)
                loss = loss + (self.crit(p_seg, ys) + self.crit(p_bou, ys_boundary))
        # NOTE(review): the sum over all steps is divided only by the length of the
        # *last* step's list (not by steps * len); preserved as-is — confirm intent.
        return loss / len(pred_seg)

    def fit(self, xs, ys, device, **kwargs):
        """Run one training step with deep supervision over all step outputs.

        Args:
            xs: input batch tensor; cast to float32 on the resolved device.
            ys: segmentation targets; cast to float32 on the resolved device. Boundary
                targets are derived from them with ``generate_BD``.
            device: list of CUDA indices or a ``Tensor.to`` device (see ``_resolve_device``).
            **kwargs: must contain ``dataset_idx``, forwarded as the model's second
                argument (presumably selects dataset-specific prompts — confirm).

        Returns:
            ``(scores, loss)``: ``outputs['output']`` and the aggregated loss tensor.

        Raises:
            KeyError: if ``dataset_idx`` is missing from ``kwargs``.
        """
        dataset_idx = kwargs['dataset_idx']
        self.opt.z_grad()
        target = self._resolve_device(device)
        xs = xs.to(device=target, dtype=torch.float32)
        ys = ys.to(device=target, dtype=torch.float32)
        ys_boundary = generate_BD(ys)
        outputs = self.model(xs, dataset_idx)
        loss = self._step_loss(outputs, ys, ys_boundary)
        loss.backward()
        self.opt.g_step()
        self.opt.update_lr()
        return outputs['output'], loss

    def predict(self, x, device, **kwargs):
        """Return ``self.model(x, kwargs['dataset_idx'])['output']`` with ``x`` as float32 on the resolved device."""
        dataset_idx = kwargs['dataset_idx']
        x = x.to(device=self._resolve_device(device), dtype=torch.float32)
        return self.model(x, dataset_idx)['output']