---
# Training configuration for a DynUNet segmenting background, liver, spleen
# and pancreas. Values starting with "$" are Python expressions, "@name"
# references another entry of this file ("#" separates nested keys), and
# "%name" reuses another entry's definition (MONAI bundle config syntax).

# ---------------------------------------------------------------------------
# Imports made available to the "$" expressions below
# ---------------------------------------------------------------------------
imports:
  - "$import glob"
  - "$import json"
  - "$import os"
  - "$import ignite"
  - "$from scipy import ndimage"
  - $import scripts
  - $import scripts.monai_utils
  - $import scripts.lr_scheduler
  - $import scripts.utils
  - $from monai.data.utils import list_data_collate
  - "$import monai.apps.deepedit.transforms"

# ---------------------------------------------------------------------------
# General settings
# ---------------------------------------------------------------------------
workflow_type: train

# One input channel; four output channels/classes (see `labels`).
input_channels: 1
output_channels: 4
output_classes: 4

#arch_ckpt_path: "$@bundle_root + '/models/dynunet_FT.pt'"
#arch_ckpt: "$torch.load(@arch_ckpt_path, map_location=torch.device('cuda'))"

# Paths; everything except dataset_dir is derived from bundle_root.
bundle_root: "."
ckpt_dir: "$@bundle_root + '/models'"
output_dir: "$@bundle_root + '/eval'"
# NOTE(review): dataset_dir is not referenced by any expression in this file.
dataset_dir: "/processed/Public/CT_TotalSegmentator/TS_split/test/"  #"/workspace/data"
data_list_file_path: "$@bundle_root + '/configs/TS_test.json'"

# Training / validation items are read from the same data-list file.
train_datalist: "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='training')"
val_datalist: "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='validation')"

# GPU indices. `device` is the first listed GPU (or CPU when CUDA is not
# available); `device_list` is built from all of them and given to the trainer.
n_gpu: [0, 1]
device: "$torch.device('cuda:' + str(@n_gpu[0]) if torch.cuda.is_available() else 'cpu')"
device_list: "$scripts.monai_utils.get_device_list(@n_gpu)"

# Patch size; its length also defines the number of spatial dimensions.
spatial_size: [96, 96, 96]
spatial_dims: "$len(@spatial_size)"

# Label name -> integer value.
labels:
  background: 0
  liver: 1
  spleen: 2
  pancreas: 3

# ---------------------------------------------------------------------------
# Network, loss, optimizer and learning-rate schedule
# ---------------------------------------------------------------------------
network_def:
  _target_: monai.networks.nets.DynUNet
  spatial_dims: "@spatial_dims"
  in_channels: "@input_channels"
  out_channels: "@output_channels"
  # Six stages; the last stride / upsample kernel is [2, 2, 1].
  kernel_size: [3, 3, 3, 3, 3, 3]
  strides: [1, 2, 2, 2, 2, [2, 2, 1]]
  upsample_kernel_size: [2, 2, 2, 2, [2, 2, 1]]
  norm_name: "instance"
  deep_supervision: false
  res_block: true

# The instantiated network moved onto `device`.
network: "$@network_def.to(@device)"

loss:
  _target_: DiceCELoss
  include_background: false
  to_onehot_y: true
  softmax: true
  squared_pred: true
  batch: true
  smooth_nr: 1.0e-06
  smooth_dr: 1.0e-06

optimizer:
  _target_: torch.optim.AdamW
  params: "$@network.parameters()"
  weight_decay: 1.0e-05
  lr: 0.00005

max_epochs: 15

# Project-local scheduler (scripts/lr_scheduler); shares max_epochs with the trainer.
lr_scheduler:
  _target_: scripts.lr_scheduler.LinearWarmupCosineAnnealingLR
  optimizer: "@optimizer"
  warmup_epochs: 10
  warmup_start_lr: 0.0000005
  eta_min: 1.0e-08
  max_epochs: "@max_epochs"

# Dictionary keys used by the transforms, and the validation interval
# consumed by the ValidationHandler below.
image_key: image
label_key: label
val_interval: 2

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
train:
  # Non-random transforms; reused unchanged as the validation preprocessing.
  deterministic_transforms:
    - _target_: LoadImaged
      keys: ["@image_key", "@label_key"]
      reader: ITKReader
    - _target_: EnsureChannelFirstd
      keys: ["@image_key", "@label_key"]
    - _target_: Orientationd
      keys: ["@image_key", "@label_key"]
      axcodes: RAS
    - _target_: Spacingd
      keys: ["@image_key", "@label_key"]
      pixdim: [1.5, 1.5, 3.0]
      # One interpolation mode per key: image, then label.
      mode: [bilinear, nearest]
    - _target_: scripts.monai_utils.AddLabelNamesd  # monai.apps.deepedit.transforms
      #_mode_: "debug"
      keys: "@label_key"
      label_names: "@labels"
    - _target_: ScaleIntensityRanged
      keys: "@image_key"
      a_min: -250
      a_max: 400
      b_min: 0
      b_max: 1
      clip: true
    - _target_: CropForegroundd
      keys: ["@image_key", "@label_key"]
      source_key: "@image_key"
      mode: ["minimum", "minimum"]
    - _target_: EnsureTyped
      keys: ["@image_key", "@label_key"]
    - _target_: CastToTyped
      keys: "@image_key"
      dtype: "$torch.float32"

  # Random cropping / augmentation, appended after the deterministic part.
  random_transforms:
    - _target_: RandCropByLabelClassesd
      keys: ["@image_key", "@label_key"]
      label_key: "@label_key"  # label4crop
      spatial_size: "@spatial_size"
      num_classes: 4
      ratios: null
      allow_smaller: true
      num_samples: 8
    # - _target_: RandSpatialCropSamplesd
    #   keys:
    #   - "@image_key"
    #   - "@label_key"
    #   roi_size: "$[int(x * 0.75) for x in @spatial_size]"
    #   num_samples: 1
    #   max_roi_size: "@spatial_size"
    #   random_center: true
    #   random_size: true
    #   allow_missing_keys: false
    - _target_: SpatialPadd
      keys: ["@image_key", "@label_key"]
      spatial_size: "@spatial_size"
      method: "symmetric"
      mode: ["minimum", "minimum"]
      allow_missing_keys: false
    - _target_: RandRotate90d
      keys: ["@image_key", "@label_key"]
      prob: 0.5
      max_k: 3
      allow_missing_keys: false
    - _target_: SelectItemsd
      keys: ["@image_key", "@label_key", "label_names"]
    - _target_: CastToTyped
      keys: ["@image_key", "@label_key"]
      dtype: ["$torch.float32", "$torch.uint8"]
    - _target_: ToTensord
      keys: ["@image_key", "@label_key"]

  # Full training pipeline = deterministic list + random list.
  preprocessing:
    _target_: Compose
    transforms: "$@train#deterministic_transforms + @train#random_transforms"

  dataset:
    _target_: PersistentDataset
    data: "@train_datalist"
    transform: "@train#preprocessing"
    cache_dir: "$@bundle_root + '/cache'"

  dataloader:
    _target_: DataLoader
    dataset: "@train#dataset"
    batch_size: 2
    shuffle: true
    num_workers: 4
    collate_fn: $list_data_collate

  inferer:
    _target_: SimpleInferer

  # Also reused as the validation postprocessing.
  postprocessing:
    _target_: Compose
    transforms:
      - _target_: Activationsd
        keys: pred
        softmax: true
      - _target_: AsDiscreted
        keys: [pred, label]
        # Per-key settings: pred is argmax'ed, label is not; both are one-hot encoded.
        argmax: [true, false]
        to_onehot: ["@output_classes", "@output_classes"]
      - _target_: scripts.monai_utils.SplitPredsLabeld  # monai.apps.deepedit.transforms
        keys: pred

  # dice_function:
  #   _target_: "$engine.state.metrics['train_dice']"

  handlers:
    - _target_: LrScheduleHandler
      lr_scheduler: "@lr_scheduler"
      print_lr: true
      # step_transform: "@dice_function"
    - _target_: ValidationHandler
      validator: "@validate#evaluator"
      epoch_level: true
      interval: "@val_interval"
    - _target_: StatsHandler
      tag_name: train_loss
      output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
    - _target_: TensorBoardStatsHandler
      log_dir: "@output_dir"
      tag_name: train_loss
      output_transform: "$monai.handlers.from_engine(['loss'], first=True)"

  key_metric:
    train_dice:
      _target_: MeanDice
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      include_background: false

  # Per-organ Dice, read from the pred_<name> / label_<name> keys.
  additional_metrics:
    liver_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_liver', 'label_liver'])"
      include_background: false
    spleen_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_spleen', 'label_spleen'])"
      include_background: false
    pancreas_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_pancreas', 'label_pancreas'])"
      include_background: false

  trainer:
    _target_: scripts.monai_utils.SupervisedTrainerMGPU  # SupervisedTrainer
    device: "@device_list"  # "@device"
    max_epochs: "@max_epochs"
    train_data_loader: "@train#dataloader"
    network: "@network"
    loss_function: "@loss"
    # train_interaction: null
    optimizer: "@optimizer"
    inferer: "@train#inferer"
    postprocessing: "@train#postprocessing"
    key_train_metric: "@train#key_metric"
    additional_metrics: "@train#additional_metrics"
    train_handlers: "@train#handlers"
    amp: true

# ---------------------------------------------------------------------------
# Validation (triggered by the ValidationHandler in train#handlers)
# ---------------------------------------------------------------------------
validate:
  # Same deterministic transforms as training, without the random ones.
  preprocessing:
    _target_: Compose
    transforms: "%train#deterministic_transforms"

  dataset:
    # _target_: CacheDataset
    # data: "@val_datalist"
    # transform: "@validate#preprocessing"
    # cache_rate: 0.025
    _target_: PersistentDataset
    data: "@val_datalist"
    transform: "@validate#preprocessing"
    cache_dir: "$@bundle_root + '/cache'"

  dataloader:
    _target_: DataLoader
    dataset: "@validate#dataset"
    batch_size: 1
    shuffle: false
    num_workers: 4
    collate_fn: $list_data_collate

  # Sliding-window inference with windows of `spatial_size`.
  inferer:
    _target_: SlidingWindowInferer
    roi_size: "@spatial_size"
    sw_batch_size: 4
    mode: "constant"
    overlap: 0.5

  postprocessing: "%train#postprocessing"

  handlers:
    - _target_: StatsHandler
      iteration_log: false
    - _target_: TensorBoardStatsHandler
      log_dir: "@output_dir"
      iteration_log: false
    - _target_: CheckpointSaver
      save_dir: "@ckpt_dir"
      save_dict:
        model: "@network"
      save_key_metric: true
      key_metric_filename: model_latest.pt

  key_metric:
    val_dice:
      _target_: MeanDice
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      include_background: false

  additional_metrics:
    val_liver_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_liver', 'label_liver'])"
      include_background: false
    val_spleen_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_spleen', 'label_spleen'])"
      include_background: false
    val_pancreas_dice:
      _target_: monai.handlers.MeanDice
      output_transform: "$monai.handlers.from_engine(['pred_pancreas', 'label_pancreas'])"
      include_background: false

  evaluator:
    _target_: SupervisedEvaluator
    device: "@device"
    val_data_loader: "@validate#dataloader"
    network: "@network"
    inferer: "@validate#inferer"
    postprocessing: "@validate#postprocessing"
    key_val_metric: "@validate#key_metric"
    additional_metrics: "@validate#additional_metrics"
    val_handlers: "@validate#handlers"
    amp: true

# ---------------------------------------------------------------------------
# Entry points
# ---------------------------------------------------------------------------
# Fixed seed for the run.
initialize:
  - "$monai.utils.set_determinism(seed=123)"

# Print a short summary of the configuration, then start the trainer.
run:
  - "$print('Training started... ')"
  - "$print('output_channels: ', @output_channels )"
  - "$print('spatial_dims: ', @spatial_dims)"
  - "$print('Labels dict: ', @labels)"
  - "$print('Get device list: ', scripts.monai_utils.get_device_list(@n_gpu))"
  #- "$[print(i,': ', data['image'].shape) for i, data in enumerate(@train#dataloader)]"
  - "$@train#trainer.run()"