Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from .optimizer import Optimizer | |
| from .crit import DiceBCE, generate_BD | |
| from collections import OrderedDict | |
| import torch.nn.functional as F | |
class BasicProcessor:
    """Common model-handling utilities shared by the processors in this module.

    This base class never creates ``self.model`` itself; every subclass assigns
    it in its own ``__init__`` before any of these helpers are used.
    """

    def __init__(self) -> None:
        pass

    def fit(self, *args, **kwargs):
        """Run one training step; must be implemented by subclasses."""
        raise NotImplementedError

    def predict(self, *args, **kwargs):
        """Run inference; must be implemented by subclasses."""
        raise NotImplementedError

    def set_mode(self, mode):
        """Switch ``self.model`` to ``'train'`` or ``'eval'`` mode.

        Raises:
            ValueError: for any other ``mode`` (a subclass of ``Exception``,
                so existing ``except Exception`` handlers still catch it).
        """
        if mode == 'train':
            self.model.train()
        elif mode == 'eval':
            self.model.eval()
        else:
            raise ValueError('Invalid model mode {}'.format(mode))

    def requires_grad_false(self):
        """Freeze the model: set ``requires_grad = False`` on every parameter."""
        self.model.requires_grad_(False)

    def set_device(self, device):
        """Place ``self.model`` on ``device``.

        ``device`` is either a device spec accepted by ``Module.to`` (e.g.
        ``'cpu'``, ``'cuda:1'``, a ``torch.device``) or a list/tuple of GPU ids.
        With several ids the model is wrapped in
        ``nn.DataParallel(device_ids=device)``. In both id-list cases the model
        is moved to ``cuda:<first id>``. DataParallel requires its parameters on
        ``device_ids[0]``, so moving to a bare ``'cuda'`` (the *current* device)
        breaks whenever the first id is not the current device.
        """
        if isinstance(device, (list, tuple)):
            primary = 'cuda:{}'.format(device[0])
            if len(device) > 1:
                self.model = nn.DataParallel(self.model, device_ids=device)
            self.model.to(primary)
        else:
            self.model.to(device)

    def save_model(self, path):
        """Write ``self.model.state_dict()`` to ``path`` with ``torch.save``."""
        torch.save(self.model.state_dict(), path)

    def load_model(self, path):
        """Load a state dict from ``path`` into ``self.model`` and print the result.

        Checkpoints written from a DataParallel-wrapped model have every key
        prefixed with ``module.``. When *all* keys carry that prefix it is
        stripped. If ``self.model`` is itself wrapped, the weights are loaded
        into the underlying ``.module``, so wrapped and unwrapped checkpoints
        and models can be mixed. The object returned by ``load_state_dict``
        (missing / unexpected keys) is printed.
        """
        prefix = 'module.'
        state_dict = torch.load(path, map_location='cpu')
        if all(key.startswith(prefix) for key in state_dict):
            state_dict = OrderedDict(
                (key[len(prefix):], value) for key, value in state_dict.items()
            )
        target = self.model
        # A wrapper expects 'module.'-prefixed keys; load into the wrapped
        # network instead so the (now unprefixed) keys match.
        if isinstance(target, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            target = target.module
        msg = target.load_state_dict(state_dict)
        print(msg)
class Processor(BasicProcessor):
    """Train/predict wrapper for a model called as ``model(x)``.

    Args:
        model: the network to train / run.
        training_params: forwarded to the project ``Optimizer`` when
            ``training`` is true.
        training: when true, build the optimizer and the ``DiceBCE``
            criterion that :meth:`fit` needs; inference-only instances skip
            both. Defaults to ``True`` like the other processors in this module.
    """

    def __init__(self, model, training_params, training=True) -> None:
        super().__init__()
        self.model = model
        if training:
            self.opt = Optimizer([self.model], training_params)
            self.crit = DiceBCE()

    @staticmethod
    def _primary_device(device):
        """Return the device that inputs and targets should be moved to.

        A list/tuple of GPU ids maps to ``cuda:<first id>``. That is where
        ``nn.DataParallel`` keeps its parameters and gathers its outputs, so
        targets must live there for the loss to be computed. Any other spec
        (``'cpu'``, ``'cuda:1'``, ``torch.device``, an int) is returned
        unchanged.
        """
        if isinstance(device, (list, tuple)):
            return 'cuda:{}'.format(device[0])
        return device

    def fit(self, xs, ys, device, **kwargs):
        """Run one training step on a batch and return ``(scores, loss)``.

        ``xs`` and ``ys`` are cast to float32 and moved straight to the
        primary device, with no CPU round-trip for tensors already on a GPU.
        The ``DiceBCE`` loss is back-propagated. Around that, the project
        ``Optimizer``'s ``z_grad`` / ``g_step`` / ``update_lr`` are called
        (presumably zero-grad, parameter step, LR-schedule update — confirm
        in ``optimizer.py``).
        """
        self.opt.z_grad()
        target = self._primary_device(device)
        xs = xs.to(device=target, dtype=torch.float32)
        ys = ys.to(device=target, dtype=torch.float32)
        scores = self.model(xs)
        loss = self.crit(scores, ys)
        loss.backward()
        self.opt.g_step()
        self.opt.update_lr()
        return scores, loss

    def predict(self, x, device, **kwargs):
        """Cast ``x`` to float32 on the primary device and return ``self.model(x)``.

        Gradient tracking and train/eval mode are left to the caller.
        """
        x = x.to(device=self._primary_device(device), dtype=torch.float32)
        return self.model(x)
class DCPProcessor(BasicProcessor):
    """Train/predict wrapper for a model called as ``model(x, dataset_idx)``.

    Args:
        model: the network. When ``training_params`` contains
            ``'prompt_lr'`` it must expose ``encoder``, ``decoder``,
            ``Last_Conv``, ``att1``..``att5``, ``cha_promot1``..``cha_promot5``
            and ``pos_promot1``..``pos_promot5``.
        training_params: forwarded to the project ``Optimizer``.
        training: when true, build the optimizer and ``DiceBCE`` criterion.
    """

    def __init__(self, model, training_params, training=True) -> None:
        super().__init__()
        self.model = model
        if training:
            if 'prompt_lr' in training_params:
                prompt_lr = training_params['prompt_lr']
                m = self.model
                # Only these sub-modules (and the prompts below) are handed to
                # the Optimizer; any other attribute of `model` is not.
                backbone = [
                    m.encoder, m.decoder, m.Last_Conv,
                    m.att1, m.att2, m.att3, m.att4, m.att5,
                ]
                # Passed as `sep_params` with `sep_lr=prompt_lr` — presumably a
                # separate learning rate for the prompts; confirm in optimizer.py.
                prompts = [
                    m.cha_promot1, m.cha_promot2, m.cha_promot3, m.cha_promot4, m.cha_promot5,
                    m.pos_promot1, m.pos_promot2, m.pos_promot3, m.pos_promot4, m.pos_promot5,
                ]
                self.opt = Optimizer(backbone, training_params,
                                     sep_lr=prompt_lr, sep_params=prompts)
            else:
                self.opt = Optimizer([self.model], training_params)
            self.crit = DiceBCE()

    @staticmethod
    def _primary_device(device):
        """Return the device that inputs and targets should be moved to.

        A list/tuple of GPU ids maps to ``cuda:<first id>``. That is where
        ``nn.DataParallel`` keeps its parameters and gathers its outputs.
        Any other spec (``'cpu'``, ``'cuda:1'``, ``torch.device``) is returned
        unchanged.
        """
        if isinstance(device, (list, tuple)):
            return 'cuda:{}'.format(device[0])
        return device

    def fit(self, xs, ys, device, **kwargs):
        """Run one training step and return ``(scores, loss)``.

        Requires ``dataset_idx`` in ``kwargs`` (a ``KeyError`` otherwise);
        it is passed to the model as its second argument. Inputs are cast to
        float32 on the primary device. The ``DiceBCE`` loss is
        back-propagated between the project ``Optimizer``'s
        ``z_grad`` / ``g_step`` / ``update_lr`` calls.
        """
        dataset_idx = kwargs['dataset_idx']
        self.opt.z_grad()
        target = self._primary_device(device)
        xs = xs.to(device=target, dtype=torch.float32)
        ys = ys.to(device=target, dtype=torch.float32)
        scores = self.model(xs, dataset_idx)
        loss = self.crit(scores, ys)
        loss.backward()
        self.opt.g_step()
        self.opt.update_lr()
        return scores, loss

    def predict(self, x, device, **kwargs):
        """Return ``self.model(x, kwargs['dataset_idx'])`` with ``x`` as float32 on the primary device."""
        dataset_idx = kwargs['dataset_idx']
        x = x.to(device=self._primary_device(device), dtype=torch.float32)
        return self.model(x, dataset_idx)
class JTFNProcessor(BasicProcessor):
    """Train/predict wrapper for a multi-step model ``model(x)`` returning a dict.

    The model output must contain ``'output'`` (the final prediction). For
    each step ``i < steps`` it must also contain lists ``'step_{i}_seg'`` and
    ``'step_{i}_bou'`` of intermediate predictions used for deep supervision.

    Args:
        model: the network.
        training_params: must contain ``'steps'`` (read even when not
            training); also forwarded to the project ``Optimizer``.
        training: when true, build the optimizer and ``DiceBCE`` criterion.
    """

    def __init__(self, model, training_params, training=True) -> None:
        super().__init__()
        self.model = model
        self.steps = training_params['steps']
        if training:
            self.opt = Optimizer([self.model], training_params)
            self.crit = DiceBCE()

    @staticmethod
    def _primary_device(device):
        """Return the device that inputs and targets should be moved to.

        A list/tuple of GPU ids maps to ``cuda:<first id>``. That is where
        ``nn.DataParallel`` keeps its parameters and gathers its outputs.
        Any other spec (``'cpu'``, ``'cuda:1'``, ``torch.device``) is returned
        unchanged.
        """
        if isinstance(device, (list, tuple)):
            return 'cuda:{}'.format(device[0])
        return device

    def fit(self, xs, ys, device, **kwargs):
        """Run one deep-supervised training step; return ``(outputs['output'], loss)``.

        ``ys`` must be 4-D (its last two sizes give the target ``(h, w)``).
        Every per-step prediction is bilinearly resized to ``(h, w)`` and
        scored with ``DiceBCE``. Segmentation heads are scored against
        ``ys`` and boundary heads against ``generate_BD(ys)``, presumably a
        boundary map (see ``crit.py``).
        """
        self.opt.z_grad()
        target = self._primary_device(device)
        xs = xs.to(device=target, dtype=torch.float32)
        ys = ys.to(device=target, dtype=torch.float32)
        ys_boundary = generate_BD(ys)
        _, _, h, w = ys.size()
        outputs = self.model(xs)
        loss = 0
        for i in range(self.steps):
            pred_seg = outputs['step_{}_seg'.format(i)]
            pred_bou = outputs['step_{}_bou'.format(i)]
            for j, seg in enumerate(pred_seg):
                p_seg = F.interpolate(seg, (h, w), mode='bilinear', align_corners=True)
                p_bou = F.interpolate(pred_bou[j], (h, w), mode='bilinear', align_corners=True)
                loss += self.crit(p_seg, ys) + self.crit(p_bou, ys_boundary)
        # NOTE(review): the original placement of this division (inside vs.
        # after the step loop) could not be recovered from the source. It is
        # applied once here, averaging over the per-step prediction list and
        # summing over steps. Confirm against the original training code.
        loss = loss / len(pred_seg)
        loss.backward()
        self.opt.g_step()
        self.opt.update_lr()
        return outputs['output'], loss

    def predict(self, x, device, **kwargs):
        """Return ``self.model(x)['output']`` with ``x`` as float32 on the primary device."""
        x = x.to(device=self._primary_device(device), dtype=torch.float32)
        outputs = self.model(x)
        return outputs['output']
class JTFNDCPProcessor(BasicProcessor):
    """Train/predict wrapper for a multi-step model ``model(x, dataset_idx)``.

    Same output contract as the JTFN processor: the returned dict holds
    ``'output'`` plus per-step lists ``'step_{i}_seg'`` / ``'step_{i}_bou'``
    for ``i < steps``. The model additionally receives ``dataset_idx``.

    Args:
        model: the network.
        training_params: must contain ``'steps'`` (read even when not
            training); also forwarded to the project ``Optimizer``.
        training: when true, build the optimizer and ``DiceBCE`` criterion.
    """

    def __init__(self, model, training_params, training=True) -> None:
        super().__init__()
        self.model = model
        self.steps = training_params['steps']
        if training:
            self.opt = Optimizer([self.model], training_params)
            self.crit = DiceBCE()

    @staticmethod
    def _primary_device(device):
        """Return the device that inputs and targets should be moved to.

        A list/tuple of GPU ids maps to ``cuda:<first id>``. That is where
        ``nn.DataParallel`` keeps its parameters and gathers its outputs.
        Any other spec (``'cpu'``, ``'cuda:1'``, ``torch.device``) is returned
        unchanged.
        """
        if isinstance(device, (list, tuple)):
            return 'cuda:{}'.format(device[0])
        return device

    def fit(self, xs, ys, device, **kwargs):
        """Run one deep-supervised training step; return ``(outputs['output'], loss)``.

        Requires ``dataset_idx`` in ``kwargs`` (a ``KeyError`` otherwise).
        ``ys`` must be 4-D. Each per-step prediction is bilinearly resized to
        ys' ``(h, w)`` and scored with ``DiceBCE``: segmentation heads
        against ``ys``, boundary heads against ``generate_BD(ys)`` (presumably
        a boundary map — see ``crit.py``).
        """
        dataset_idx = kwargs['dataset_idx']
        self.opt.z_grad()
        target = self._primary_device(device)
        xs = xs.to(device=target, dtype=torch.float32)
        ys = ys.to(device=target, dtype=torch.float32)
        ys_boundary = generate_BD(ys)
        _, _, h, w = ys.size()
        outputs = self.model(xs, dataset_idx)
        loss = 0
        for i in range(self.steps):
            pred_seg = outputs['step_{}_seg'.format(i)]
            pred_bou = outputs['step_{}_bou'.format(i)]
            for j, seg in enumerate(pred_seg):
                p_seg = F.interpolate(seg, (h, w), mode='bilinear', align_corners=True)
                p_bou = F.interpolate(pred_bou[j], (h, w), mode='bilinear', align_corners=True)
                loss += self.crit(p_seg, ys) + self.crit(p_bou, ys_boundary)
        # NOTE(review): the original placement of this division (inside vs.
        # after the step loop) could not be recovered from the source. It is
        # applied once here, averaging over the per-step prediction list and
        # summing over steps. Confirm against the original training code.
        loss = loss / len(pred_seg)
        loss.backward()
        self.opt.g_step()
        self.opt.update_lr()
        return outputs['output'], loss

    def predict(self, x, device, **kwargs):
        """Return ``self.model(x, kwargs['dataset_idx'])['output']`` with ``x`` as float32 on the primary device."""
        dataset_idx = kwargs['dataset_idx']
        x = x.to(device=self._primary_device(device), dtype=torch.float32)
        outputs = self.model(x, dataset_idx)
        return outputs['output']