import os

# --- Environment configuration -------------------------------------------
# Everything in this section has to run BEFORE the third-party imports below.
#
# OpenBLAS reads OPENBLAS_NUM_THREADS once, when the shared library is loaded
# (i.e. on the first import of numpy or of anything that pulls numpy in).
# Setting it after those imports has no effect, so it is set first.
os.environ["OPENBLAS_NUM_THREADS"] = "1"

# nnU-Net v2 resolves its data/model folders from these variables when
# `nnunetv2` is imported, so they must be defined ahead of that import.
# NOTE(review): nnUNet_results is an absolute, machine-specific path —
# confirm it matches the deployment/container layout.
os.environ["nnUNet_raw"] = "./nnunet_raw"
os.environ["nnUNet_preprocessed"] = "./nnunet_preprocessed"
os.environ["nnUNet_results"] = "/home/head_neck/algorithm-template/nnunet_results_5"

# --- Standard library ------------------------------------------------------
import shutil
import subprocess
import tempfile
from typing import Dict

# --- Third party -----------------------------------------------------------
import numpy as np
import SimpleITK as sitk
import torch
from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \
    save_json
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# --- Project local ---------------------------------------------------------
from base_algorithm import BaseSynthradAlgorithm
from revert_normalisation import get_ct_normalisation_values, revert_normalisation_single_modified
| |
|
| |
|
| |
|
| |
|
class SynthradAlgorithm(BaseSynthradAlgorithm):
    """
    Synthetic CT generation from an MR volume with region-specific nnU-Net models.

    ``predict`` selects the trained model folder that belongs to the requested
    anatomical region, runs inference with folds 0-4 of that model on the MR
    volume, and passes the result through ``revert_normalisation_single_modified``
    together with the CT mean/std read from the region's plans file.

    Author: Suraj Pai (b.pai@maastrichtuniversity.nl)
    """

    # Trainer / plans / configuration folder name shared by all regional models.
    _RESULT_FOLDER = "nnUNetTrainerMRCT_loss_masked_perception_masked__nnUNetResEncUNetLPlans__3d_fullres"

    # Region label -> (dataset folder under nnUNet_results, plans file holding
    # the CT normalisation values for that dataset).
    _REGION_MODELS = {
        "Head and Neck": ("Dataset542", "./542_gt_nnUNetResEncUNetLPlans.json"),
        "Abdomen": ("Dataset540", "./540_gt_nnUNetResEncUNetLPlans.json"),
        "Thorax": ("Dataset544", "./544_gt_nnUNetResEncUNetLPlans.json"),
    }

    # Keys that must be present in the dictionary handed to ``predict``.
    _REQUIRED_INPUTS = ("image", "mask", "region")

    # The predictor is asked to write "<stem><file ending>"; the code below
    # expects that file to appear as ./TRUNCATED.nii.gz in the working directory.
    _PREDICTION_STEM = 'TRUNCATED'
    _PREDICTION_PATH = "./TRUNCATED.nii.gz"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def predict(self, input_dict: Dict[str, sitk.Image]) -> sitk.Image:
        """
        Generate a synthetic CT for one case.

        Args:
            input_dict: Mapping with the keys ``"image"`` (MR volume),
                ``"mask"`` (mask volume) and ``"region"`` (one of
                ``"Head and Neck"``, ``"Abdomen"``, ``"Thorax"``).

        Returns:
            The image returned by ``revert_normalisation_single_modified`` for
            the network prediction.

        Raises:
            ValueError: If a required key is missing or the region is not one
                of the supported labels.
        """
        # Explicit validation instead of `assert`: asserts vanish under -O and
        # the previous check also depended on dict insertion order.
        missing = [key for key in self._REQUIRED_INPUTS if key not in input_dict]
        if missing:
            raise ValueError(
                f"input_dict is missing required key(s) {missing}; "
                f"expected {list(self._REQUIRED_INPUTS)}"
            )

        region = input_dict["region"]
        mr_sitk = input_dict["image"]
        mask_sitk = input_dict["mask"]

        # Fail early with a readable message; previously an unknown region
        # surfaced later as a NameError on `dataset_name`.
        try:
            dataset_name, plans_path = self._REGION_MODELS[region]
        except (KeyError, TypeError):
            raise ValueError(
                f"Unsupported region {region!r}; "
                f"expected one of {sorted(self._REGION_MODELS)}"
            ) from None

        # NOTE(review): earlier revisions built a copy of the MR with all
        # voxels outside the mask set to 0, but that copy was never passed to
        # the network — inference has always run on the unmasked MR. The dead
        # computation was removed; if the models were trained on masked inputs
        # the masked array should be fed here instead. Confirm with training.

        predictor = nnUNetPredictor(
            tile_step_size=0.5,
            use_gaussian=True,
            use_mirroring=True,
            perform_everything_on_device=True,
            device=torch.device('cuda', 0),
            verbose=True,
            verbose_preprocessing=True,
            allow_tqdm=True,
        )
        predictor.initialize_from_trained_model_folder(
            join(os.environ["nnUNet_results"], f'{dataset_name}/{self._RESULT_FOLDER}'),
            use_folds=(0, 1, 2, 3, 4),
            checkpoint_name='checkpoint_final.pth',
        )

        sitk_spacing = mr_sitk.GetSpacing()
        props = {
            # Geometry in SimpleITK (x, y, z) order.
            'sitk_stuff': {
                'spacing': tuple(sitk_spacing),
                'origin': tuple(mr_sitk.GetOrigin()),
                'direction': tuple(mr_sitk.GetDirection()),
            },
            # Same spacing reversed to (z, y, x), matching the axis order of
            # the array returned by GetArrayFromImage.
            'spacing': [sitk_spacing[2], sitk_spacing[1], sitk_spacing[0]],
        }

        # Add a leading axis of length 1 in front of the volume axes.
        img = np.expand_dims(sitk.GetArrayFromImage(mr_sitk).astype(np.float32), 0)

        try:
            # With an output file stem given, the prediction is written to
            # disk and read back from there rather than taken from the return
            # value.
            predictor.predict_single_npy_array(img, props, None, self._PREDICTION_STEM, False)
            pred_sitk = sitk.ReadImage(self._PREDICTION_PATH)

            ct_mean, ct_std = get_ct_normalisation_values(plans_path)
            mask_sitk = sitk.Cast(mask_sitk, sitk.sitkUInt8)
            pred_sitk = revert_normalisation_single_modified(
                pred_sitk, ct_mean, ct_std, mask_sitk=mask_sitk
            )
        finally:
            # Always tidy up the working directory, including when inference
            # or post-processing fails, so a stale prediction from one case
            # can never be picked up by the next.
            if os.path.isfile(self._PREDICTION_PATH):
                os.remove(self._PREDICTION_PATH)
            shutil.rmtree("./imagesTs", ignore_errors=True)
            shutil.rmtree("./predictions", ignore_errors=True)

        return pred_sitk
| |
|
if __name__ == '__main__':
    # Script entry point: build the algorithm and hand control to the
    # processing routine inherited from the base class.
    algorithm = SynthradAlgorithm()
    algorithm.process()
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |