
import numpy as np
import nibabel as nib
import os

from glob import glob
from scipy.ndimage.measurements import label
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser


# inspired by the NVIDIA nnU-Net GitHub repository available at:
# https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Segmentation/nnUNet


def back_to_original_labels(pred, preop):
    """
    Convert a channel-wise (WT, TC, ET) prediction back to a single BraTS label volume.

    Channel 0 is thresholded as whole tumor (WT, > 0.45), channel 1 as tumor core
    (TC, > 0.4) and channel 2 as enhancing tumor (ET, > 0.4). The raw values of
    channel 2 are also used as the confidence for the small-component clean-up.

    Output labels:
      * pre-operative (``preop=True``): 0 background, 1 NCR, 2 ED, 4 ET.
      * post-operative (``preop=False``): labels 1 and 4 swap roles, i.e. 1 marks
        voxels above the ET threshold and 4 marks the remaining TC voxels; 2 still
        marks WT voxels outside TC. (Clinical meaning of the post-op labels is not
        stated here — NOTE(review): confirm against the post-op label spec.)

    :param pred: array whose first axis holds the (WT, TC, ET) channels; each channel
        must be a 3-D volume because the result is transposed with axes (2, 1, 0).
        Presumably per-voxel probabilities — TODO confirm.
    :param preop: True for the pre-operative label convention (which also enables the
        whole-volume ET clean-up); False for the post-operative convention.
    :return: uint8 label volume with the spatial axes reversed (transpose (2, 1, 0)).
    """
    enh = pred[2]
    wt = pred[0] > 0.45
    tc = pred[1] > 0.4
    et = pred[2] > 0.4

    # Label used for "enhancing" voxels and the label small enhancing blobs fall back to.
    et_label, fallback_label = (4, 1) if preop else (1, 4)

    # Voxels below the WT threshold are background (0); the rest start as label 1.
    labels = wt.astype(np.uint8)
    # Inside WT but below the TC threshold -> label 2 (ED).
    labels[wt & ~tc] = 2
    if not preop:
        # Post-operative convention: the non-enhancing TC voxels take label 4.
        labels[labels == 1] = 4
    # Inside WT and above the ET threshold -> enhancing label.
    labels[wt & et] = et_label

    # Connected components of the enhancing label. The label choice must be
    # parenthesised so the conditional selects the label, not the whole argument.
    components, n = label(labels == et_label)
    for comp_idx in range(1, n + 1):
        comp = components == comp_idx
        size = int(comp.sum())
        # Components of 2..15 voxels with mean ET confidence below 0.9 are demoted to
        # the fallback label so the voxels still count as part of the tumor core.
        if 1 < size < 16 and np.mean(enh[comp]) < 0.9:
            labels[comp] = fallback_label

    if preop:
        et_mask = labels == 4
        n_et = et_mask.sum()
        # Fewer than 73 ET voxels overall with mean confidence below 0.9: treat all of
        # them as NCR.
        if 0 < n_et < 73 and np.mean(enh[et_mask]) < 0.9:
            labels[et_mask] = 1

    # Transpose to fit BraTS orientation.
    return np.transpose(labels, (2, 1, 0)).astype(np.uint8)


def prepare_preditions(
    example,
    preop,
    data_dir="./data/BraTS2021_val/images",
    out_dir="./results/final_preds",
):
    """
    Ensemble one case's predictions, convert them to BraTS labels and save as NIfTI.

    :param example: sequence of ``.npy`` prediction files for the same case (one per
        ensemble member); they are averaged element-wise before label conversion. The
        case name is the first file's basename up to its first ".".
    :param preop: forwarded to ``back_to_original_labels`` to pick the pre- or
        post-operative label convention.
    :param data_dir: folder containing ``<case>.nii.gz``, whose affine and header are
        reused for the output (defaults to the previously hard-coded location).
    :param out_dir: folder the ``<case>.nii.gz`` segmentation is written to (must exist).
    :return: None; the segmentation is written to ``out_dir``.
    """
    fname = os.path.basename(example[0]).split(".")[0]
    # Average all ensemble members, then convert back to the original BraTS labels.
    probs = np.mean([np.load(f) for f in example], axis=0)
    seg = back_to_original_labels(probs, preop)

    # Reuse the reference image's geometry so the segmentation overlays it.
    img = nib.load(os.path.join(data_dir, f"{fname}.nii.gz"))
    header = img.header.copy()
    # nibabel keeps the on-disk dtype of a supplied header; store the label map as uint8
    # instead of inheriting the reference image's intensity dtype.
    header.set_data_dtype(np.uint8)
    nib.save(
        nib.Nifti1Image(seg, img.affine, header=header),
        os.path.join(out_dir, f"{fname}.nii.gz"),
    )


# retrieve args from command line
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument(
    "--type",
    type=str,
    choices=["preop", "postop"],
    # Omitting --type used to yield None, which the script treats as post-operative
    # while --help advertised "(default: None)"; make that default explicit.
    default="postop",
    help="Choose between pre- or post-operative",
)


if __name__ == "__main__":
    args = parser.parse_args()
    out_dir = "./results/final_preds"
    # exist_ok: re-running the script must not crash because the folder already exists.
    os.makedirs(out_dir, exist_ok=True)

    # Each ./results/predictions* directory holds one ensemble member's .npy files.
    pred_dirs = sorted(d for d in glob("./results/predictions*") if os.path.isdir(d))
    if not pred_dirs:
        raise SystemExit("No ./results/predictions* directories found")
    per_dir_files = [sorted(glob(os.path.join(d, "*.npy"))) for d in pred_dirs]
    # zip() would silently drop cases if the directories disagree, so check explicitly.
    if len({len(files) for files in per_dir_files}) != 1:
        raise SystemExit("Prediction directories contain different numbers of .npy files")
    examples = list(zip(*per_dir_files))
    for example in examples:
        if len({os.path.basename(f) for f in example}) != 1:
            raise SystemExit(f"Mismatched prediction files across directories: {example}")

    print("Preparing final predictions")
    preop = args.type == "preop"
    for example in examples:
        prepare_preditions(example, preop=preop)
    print("Finished!")
