# Transfer Learning for Post-operative Brain Tumor Segmentation

This repository provides the scripts required to train the nnU-Net model on pre-operative brain MRI scans and to apply transfer learning to post-operative ones.

## Table of contents

- [Theoretical Background](#theoretical-background)
    * [Transfer Learning](#transfer-learning)
    * [BraTS 2021](#brats-2021)
- [Quick Start Guide](#quick-start-guide)
    * [Pre-operative](#pre-operative)
    * [Post-operative](#post-operative)

## Theoretical Background
Deep Learning (DL) has achieved cutting-edge results in several medical fields, with applications ranging from lesion segmentation to disease relapse prediction. Neuro-oncology is one of these fields and has seen important advances, especially in automating neuroradiology tasks such as brain tumor detection and segmentation. However, while DL methods have achieved state-of-the-art results in brain tumor segmentation on pre-operative MRI scans, the same can hardly be said of post-operative segmentation: the literature lacks a comprehensive study, and the few models proposed so far still perform well below human level.

### Transfer Learning
Due to the scarcity of data available in clinical practice, Transfer Learning (TL) has become increasingly popular in the medical field, since it allows models to be trained without a large dataset by leveraging knowledge learned on other source tasks. Still, current TL techniques in medical imaging mostly transfer knowledge from natural images, usually from models trained on the ImageNet dataset. Although some progress has been made, the knowledge transferred between the two domains may be insufficient to achieve good results on the medical task, or may make the transfer process quite unpredictable. This work builds on the intuition that TL from pre- to post-operative brain tumor segmentation could yield promising results, because the source and target domains are close and the knowledge transfer never leaves the medical field.

### BraTS 2021
The RSNA-ASNR-MICCAI Brain Tumor Segmentation (BraTS) challenge is a project started in 2012 in conjunction with the MICCAI conference, with the goal of becoming the de facto benchmark for the automated segmentation of tumor sub-regions from pre-operative multi-parametric Magnetic Resonance Imaging (mpMRI) scans. The BraTS 2021 dataset is a collection of 1251 patients, each with four scan modalities (FLAIR, T1, T1ce and T2) acquired with different scanners and protocols at several institutions, so image quality is highly heterogeneous (please refer to ["The RSNA-ASNR-MICCAI BraTS 2021 Benchmark on Brain Tumor Segmentation and Radiogenomic Classification"](https://arxiv.org/abs/2107.02314) for further details). The figure below shows the four modalities and the manual segmentation for two different patients from the BraTS 2021 dataset. The ground truth segmentation comprises the three pre-operative classes of necrotic core (blue), enhancing tumor (yellow) and edema (turquoise).

<img src="images/BraTS21_examples.png" width="1200"/>
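For reference, BraTS 2021 label maps encode the necrotic core as 1, edema as 2 and enhancing tumor as 4, and segmentation quality is assessed on three overlapping regions built from these labels: whole tumor, tumor core and enhancing tumor. The following minimal sketch, assuming NumPy and a hypothetical helper name that is not part of this repository, shows how the three labels map onto those regions.

```python
import numpy as np

# BraTS 2021 label values.
NCR, ED, ET = 1, 2, 4  # necrotic core, peritumoral edema, enhancing tumor


def labels_to_regions(seg: np.ndarray) -> np.ndarray:
    """Turn a BraTS label map into three overlapping binary regions.

    Output shape is (3, *seg.shape) with channels:
      0: whole tumor (WT)     = NCR + ED + ET
      1: tumor core (TC)      = NCR + ET
      2: enhancing tumor (ET) = ET
    """
    wt = np.isin(seg, [NCR, ED, ET])
    tc = np.isin(seg, [NCR, ET])
    et = seg == ET
    return np.stack([wt, tc, et]).astype(np.uint8)


# One voxel per class: background, necrosis, edema, enhancing tumor.
regions = labels_to_regions(np.array([0, 1, 2, 4]))
```

Because the regions are nested (ET inside TC inside WT), an enhancing voxel is positive in all three channels.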


## Quick Start Guide
The implemented architecture is the one developed by NVIDIA data scientists, which [ranked first on the BraTS 2021 validation leaderboard](https://developer.nvidia.com/blog/nvidia-data-scientists-take-top-spots-in-miccai-2021-brain-tumor-segmentation-challenge/). Please refer to ["Optimized U-Net for Brain Tumor Segmentation"](https://arxiv.org/abs/2110.03352) for a complete overview. The code is a streamlined version of the NVIDIA nnU-Net implementation: it gives up the generality typical of nnU-Net and keeps only the structure needed for pre-operative brain tumor segmentation on the BraTS dataset. It is strongly recommended to have a look at the [NVIDIA nnU-Net official implementation on GitHub](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Segmentation/nnUNet) for an in-depth understanding of all the technicalities.

### Pre-operative
To download the training and validation datasets, you need an account on https://www.synapse.org and must be registered for the BraTS 2021 challenge. The scripts assume, as a starting point, that both the training and validation .zip files have been extracted into the `/data` folder as follows:
```
 /data
  │
  ├────BraTS2021_train
  │      ├──BraTS2021_00000
  │      │      └──BraTS2021_00000_flair.nii.gz
  │      │      └──BraTS2021_00000_t1.nii.gz
  │      │      └──BraTS2021_00000_t1ce.nii.gz
  │      │      └──BraTS2021_00000_t2.nii.gz
  │      │      └──BraTS2021_00000_seg.nii.gz
  │      ├──BraTS2021_00002
  │      │      └──BraTS2021_00002_flair.nii.gz
  │      ...    └──...
  │
  └────BraTS2021_val
         ├──BraTS2021_00001
         │      └──BraTS2021_00001_flair.nii.gz
         │      └──BraTS2021_00001_t1.nii.gz
         │      └──BraTS2021_00001_t1ce.nii.gz
         │      └──BraTS2021_00001_t2.nii.gz
         ├──BraTS2021_00002
         │      └──BraTS2021_00002_flair.nii.gz
         ...    └──...
```
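Before running the scripts, it can help to verify that every patient folder follows this layout. The sketch below is a hypothetical helper, not part of the repository; it assumes that each file is named after its folder plus a modality suffix, and that only the training split contains `_seg` files.

```python
from pathlib import Path

MODALITIES = ("flair", "t1", "t1ce", "t2")


def missing_files(root: Path, split: str, with_seg: bool) -> list:
    """List expected NIfTI files that are missing under root/split.

    Each patient folder (e.g. BraTS2021_00000) is expected to contain
    <folder>_<modality>.nii.gz for every modality, plus <folder>_seg.nii.gz
    when with_seg is True (training split).
    """
    suffixes = MODALITIES + (("seg",) if with_seg else ())
    missing = []
    for patient_dir in sorted(p for p in (root / split).iterdir() if p.is_dir()):
        for suffix in suffixes:
            expected = patient_dir / f"{patient_dir.name}_{suffix}.nii.gz"
            if not expected.is_file():
                missing.append(str(expected))
    return missing
```

For example, `missing_files(Path("/data"), "BraTS2021_train", with_seg=True)` should return an empty list when the training data is complete, and the same holds for `"BraTS2021_val"` with `with_seg=False`.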
As a first step, the selected modalities are stacked so that each patient is represented by a single tensor of shape (C, 240, 240, 155), where C is the number of selected modalities. Specifically, by calling
```
 python ./prepare.py [--flair] [--t1] [--t1ce] [--t2]
```
all data will be prepared as described above, considering _**only**_ the modalities specified as command line arguments. The next step preprocesses all patients by cropping the volumes and normalizing their intensities, and also adds an auxiliary one-hot-encoded channel (please refer to [the original paper](https://arxiv.org/abs/2110.03352)). This is done by calling
```
 python ./preprocess.py --task train --ohe --exec_mode training
 python ./preprocess.py --task val --ohe --exec_mode test
```
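The two steps above can be sketched in NumPy as follows. This is an approximation of what the paper describes (stacking, cropping to the non-zero bounding box, z-score normalization on non-zero voxels, and an extra channel marking foreground voxels), not the actual code of `prepare.py` or `preprocess.py`. The function names and the fixed channel order in `MODALITY_ORDER` are assumptions, and loading the NIfTI files (for example with nibabel) is not shown.

```python
import numpy as np

MODALITY_ORDER = ("flair", "t1", "t1ce", "t2")  # assumed channel order


def stack_modalities(volumes: dict, selected) -> np.ndarray:
    """Stack the selected 3D volumes into one (C, H, W, D) array.

    `volumes` maps modality names to arrays of identical shape, for instance
    loaded with nibabel as nib.load(path).get_fdata() (loading not shown).
    """
    chosen = [m for m in MODALITY_ORDER if m in selected]
    if not chosen:
        raise ValueError("select at least one modality")
    return np.stack([volumes[m] for m in chosen], axis=0)


def crop_to_foreground(image: np.ndarray) -> np.ndarray:
    """Crop (C, H, W, D) to the bounding box of voxels non-zero in any channel."""
    mask = np.any(image != 0, axis=0)
    if not mask.any():
        raise ValueError("image is empty")
    coords = np.argwhere(mask)
    lo, hi = coords.min(axis=0), coords.max(axis=0) + 1
    return image[(slice(None),) + tuple(slice(a, b) for a, b in zip(lo, hi))]


def preprocess(image: np.ndarray) -> np.ndarray:
    """Crop, z-score each channel on its non-zero voxels, then append a foreground channel."""
    image = crop_to_foreground(image).astype(np.float32)
    foreground = np.any(image != 0, axis=0)  # computed before normalization
    out = np.zeros_like(image)
    for c in range(image.shape[0]):
        brain = image[c] != 0
        if brain.any():
            values = image[c][brain]
            out[c][brain] = (values - values.mean()) / max(values.std(), 1e-8)
    # Extra input channel: 1 on foreground voxels, 0 on background.
    return np.concatenate([out, foreground[None].astype(np.float32)], axis=0)
```

With all four modalities selected, a BraTS volume of shape (4, 240, 240, 155) becomes (5, h, w, d), where (h, w, d) is the size of the cropped bounding box.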
At this point, `ls /data` should list the following four directories:
```
 BraTS2021_train   BraTS2021_val   train_3d   val_3d
```
If that is the case, you can proceed with the actual training. By calling
```
 python ./main.py --deep_supervision --depth 6 --filters 64 96 128 192 256 384 512 --min_fmap 2 --scheduler --learning_rate 0.0003 --epochs 150 --fold 0 --amp --gpus 1 --task train --save_ckpt
```
the nnU-Net model will train for 150 epochs on the data contained in `/data/train_3d`, using mixed precision if `[--amp]` is included as an argument. Once training has finished, you can run inference on the unseen patients contained in `/data/val_3d` by calling
```
 python ./main.py --depth 6 --filters 64 96 128 192 256 384 512 --min_fmap 2 --gpus 1 --amp --save_preds --exec_mode predict --data /data/val_3d/test --ckpt_path /results/checkpoints/epoch=AAA-dice_mean=BB.CC.ckpt --tta
```
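The prediction command combines sliding-window inference with optional test-time augmentation. The NumPy sketch below illustrates both ideas; it is not the repository's inference code, and the function names are hypothetical. It assumes 50% window overlap, uniform averaging of overlapping predictions (the NVIDIA implementation may use a different overlap and a Gaussian weighting) and flips over all combinations of the three spatial axes.

```python
import itertools

import numpy as np


def window_starts(size: int, window: int, overlap: float = 0.5) -> list:
    """Start offsets of windows covering [0, size) with fractional overlap.

    The last window is shifted back so that it ends exactly at `size`.
    A single window is returned when the axis is not larger than the window
    (in practice the volume would be padded first).
    """
    if size <= window:
        return [0]
    step = max(1, int(window * (1 - overlap)))
    starts = list(range(0, size - window + 1, step))
    if starts[-1] != size - window:
        starts.append(size - window)
    return starts


def sliding_window_predict(volume, model, window, n_out, overlap=0.5):
    """Average `model` outputs over overlapping patches of a (C, H, W, D) volume.

    `model` maps a (C, h, w, d) patch to an (n_out, h, w, d) array.
    """
    spatial = volume.shape[1:]
    out = np.zeros((n_out, *spatial), dtype=np.float32)
    count = np.zeros(spatial, dtype=np.float32)
    axes = [window_starts(s, w, overlap) for s, w in zip(spatial, window)]
    for starts in itertools.product(*axes):
        box = tuple(slice(s, s + w) for s, w in zip(starts, window))
        out[(slice(None),) + box] += model(volume[(slice(None),) + box])
        count[box] += 1
    return out / count


def flip_tta(volume, predict):
    """Average predictions over all 8 flip combinations of the 3 spatial axes."""
    preds = []
    for r in range(4):
        for axes in itertools.combinations((1, 2, 3), r):
            flipped = np.flip(volume, axis=axes) if axes else volume
            pred = predict(flipped)
            preds.append(np.flip(pred, axis=axes) if axes else pred)  # undo the flip
    return np.mean(preds, axis=0)
```

The two can be composed, e.g. `flip_tta(volume, lambda v: sliding_window_predict(v, model, (128, 128, 128), n_out=3))`, where `model` is any callable that returns per-region scores for a patch.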
Inference is performed with a sliding window, since inputs can have arbitrary shapes. Test-time augmentation is applied if `[--tta]` is included as an argument, and predictions are saved if `[--save_preds]` is present. Please replace the checkpoint path with your own (by default, .ckpt files are saved in `/results/checkpoints/foldA`). Once the process has finished, by calling
```
 python ./postprocess.py --type preop
```
each prediction will be re-oriented to match the BraTS guidelines, and post-processing will convert the overlapping regions back to the standard classes of edema, enhancing tumor and necrosis (by default, output files are saved in `/results/final_preds`).
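The region-to-class conversion for the pre-operative case can be sketched as below. This is a simplified illustration, not the repository's post-processing script. It assumes the network outputs probabilities for whole tumor, tumor core and enhancing tumor in that channel order, a 0.5 threshold, and the BraTS 2021 label values; the re-orientation step and any further cleanup rules the script may apply are not shown.

```python
import numpy as np


def regions_to_labels(probs: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Map (3, ...) region probabilities (WT, TC, ET) to a BraTS 2021 label map.

    Output labels: 0 background, 1 necrotic core, 2 edema, 4 enhancing tumor.
    Voxels predicted as TC or ET but not as WT are left as background.
    """
    wt, tc, et = probs > threshold
    labels = np.zeros(probs.shape[1:], dtype=np.uint8)
    labels[wt] = 2       # whole tumor voxels start as edema
    labels[wt & tc] = 1  # core voxels inside WT become necrosis
    labels[wt & et] = 4  # enhancing voxels inside WT become enhancing tumor
    return labels
```

The order of the assignments matters: later, more specific regions overwrite earlier ones, which undoes the nesting of the overlapping regions.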

### Post-operative
This repository extends the [NVIDIA official implementation](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Segmentation/nnUNet) by enabling transfer learning. New data can be placed in the `/data/BraTS2021_train` and `/data/BraTS2021_val` folders to be prepared, preprocessed and fed to the network in the same way as above. Please remember to harmonize it beforehand in a BraTS-like manner, i.e. by co-registering all scans to the same anatomical template, skull-stripping them and renaming them as `BraTS2021_P*N*_suffix.nii.gz`, where the two `*` are the IDs of the patient and of the scan, respectively (please refer to [the original paper](https://arxiv.org/abs/2107.02314) for further details).
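A quick way to check the renaming before running `prepare.py` is to parse each file name with a regular expression. The sketch below is hypothetical: it assumes that both IDs are numeric (e.g. `BraTS2021_P12N3_t1ce.nii.gz`) and that the suffix is one of the BraTS modality or segmentation suffixes.

```python
import re
from typing import NamedTuple, Optional

# Assumed pattern: numeric patient and scan IDs, BraTS-style suffixes.
POSTOP_NAME = re.compile(r"^BraTS2021_P(\d+)N(\d+)_(flair|t1|t1ce|t2|seg)\.nii\.gz$")


class PostopFile(NamedTuple):
    patient: str
    scan: str
    suffix: str


def parse_postop_name(name: str) -> Optional[PostopFile]:
    """Return the patient ID, scan ID and suffix, or None if the name does not match."""
    m = POSTOP_NAME.match(name)
    return PostopFile(*m.groups()) if m else None
```

Files for which `parse_postop_name` returns `None` do not follow the assumed convention and should be renamed before preparation.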

<img src="images/unet-brats.jpg" width="1000"/>

The argument `[--freeze]` accepts a non-negative integer indicating how many "depth levels" to freeze during training (0 means fine-tuning the whole architecture). For the U-Net structure, depth levels are counted from left to right, and up to `2 * [--depth] + 1` blocks of the encoder-decoder structure can be frozen (the output block and the deep supervision heads are excluded). TL methods often replace the last layer to adapt the network to the new task. However, the current code does not support this, because the Radiation Therapy Oncology Group (RTOG) guideline for post-operative glioblastoma is chosen as reference, and it also defines three classes: resection cavity, residual enhancing tumor and surrounding edema. After the prediction process has finished, by calling
```
 python ./postprocess.py --type postop
```
post-processing will be applied as in the pre-operative case, re-orienting predictions and converting the overlapping regions back to the standard classes of edema, enhancing tumor and resection cavity. If your own data include low-resolution scans (for example, non-volumetric acquisitions) that should be used for training only, the argument `[--not_val]` accepts the path to a .txt file whose rows are the patient IDs (`P*`) of the patients with such acquisitions.
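The effect of `[--not_val]` can be illustrated by reading the .txt file and separating the affected cases from those eligible for validation. This is a hypothetical sketch: the function names are illustrative, case folders are assumed to be named like `BraTS2021_P12N3`, and how the repository actually assigns cases to folds is not shown here.

```python
import re
from pathlib import Path

# Assumed case-name pattern, e.g. "BraTS2021_P12N3" -> patient ID "P12".
CASE_RE = re.compile(r"^BraTS2021_(P\d+)N\d+")


def read_train_only_ids(txt_path) -> set:
    """Read one patient ID per row (e.g. "P12"), skipping blank rows."""
    rows = Path(txt_path).read_text().splitlines()
    return {row.strip() for row in rows if row.strip()}


def split_cases(case_names, train_only: set):
    """Return (train_only_cases, val_candidates) according to the patient ID."""
    train_only_cases, val_candidates = [], []
    for name in case_names:
        m = CASE_RE.match(name)
        if m and m.group(1) in train_only:
            train_only_cases.append(name)  # never used for validation
        else:
            val_candidates.append(name)
    return train_only_cases, val_candidates
```

All scans of a listed patient are kept for training, since the match is done on the patient ID rather than on the scan ID.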
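The `[--freeze]` option described above amounts to disabling gradients for the leftmost blocks of the U-Net. The sketch below shows the idea with hypothetical helper names. It assumes the `2 * depth + 1` blocks are `depth` encoder blocks, one bottleneck and `depth` decoder blocks, ordered left to right, and that each block exposes `.parameters()` as a `torch.nn.Module` does. How the repository itself enumerates the blocks is not shown here.

```python
def level_name(i: int, depth: int) -> str:
    """Hypothetical naming of block i, counting levels left to right."""
    if i < depth:
        return f"encoder level {i}"
    if i == depth:
        return "bottleneck"
    return f"decoder block {i - depth}"  # 1..depth, counted from the bottleneck


def freeze_first_levels(blocks, n_freeze: int, depth: int) -> None:
    """Set requires_grad=False on the parameters of the first `n_freeze` blocks.

    `blocks` lists the 2 * depth + 1 encoder-decoder blocks from left to right
    (output block and deep supervision heads excluded); each element must
    provide .parameters(), as torch.nn.Module does.
    """
    n_blocks = 2 * depth + 1
    if len(blocks) != n_blocks:
        raise ValueError(f"expected {n_blocks} blocks, got {len(blocks)}")
    if not 0 <= n_freeze <= n_blocks:
        raise ValueError(f"freeze must be in [0, {n_blocks}], got {n_freeze}")
    for i, block in enumerate(blocks):
        for p in block.parameters():
            p.requires_grad = i >= n_freeze  # frozen blocks get no gradient
```

With `--depth 6`, valid values range from 0 to 13, and `--freeze 7` would freeze the six encoder levels and the bottleneck. In PyTorch, the optimizer would then typically receive only the trainable parameters, e.g. `torch.optim.Adam(p for p in model.parameters() if p.requires_grad)`.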
