
import numpy as np
import pandas as pd
import pydicom as dicom
import nibabel as nib
import os
import shutil
import zipfile
import SimpleITK as sitk


def build_t1ncr_t1core(patient_info, result_dir):
    """
    Function that builds the two burned-in T1ce mri (one per manual segmentation) and converts each to NIfTI.
    For each of the "T1ce-core" and "T1ce-ncr" DICOM series, every slice is copied into result_dir/t1core and
    result_dir/t1ncr respectively (inverting the InstanceNumber when "Inv-burnedin" is set); each copied series is
    then written as ready_to_CaPTk_T1core.nii.gz / ready_to_CaPTk_T1ncr.nii.gz in the patient folder of data_dir.
    :param patient_info: pandas series with patient information (its name is the patient folder; the "Outer",
        "T1ce-core", "T1ce-ncr" and "Inv-burnedin" entries are read)
    :param result_dir: directory in which the "t1core" and "t1ncr" DICOM sub-folders are written
    """
    name = patient_info.name
    outer = str(patient_info.loc["Outer"])
    invert_burnedin = bool(patient_info.loc["Inv-burnedin"])
    # (source series folder, output sub-folder, suffix of the NIfTI written for CaPTk)
    series = [
        (str(patient_info.loc["T1ce-core"]), "t1core", "T1core"),
        (str(patient_info.loc["T1ce-ncr"]), "t1ncr", "T1ncr"),
    ]

    for folder, modality, suffix in series:
        input_dir = os.path.join(data_dir, name, outer, folder)
        output_dir = os.path.join(result_dir, modality)
        os.makedirs(output_dir, exist_ok=True)

        files = os.listdir(input_dir)
        # slice count of THIS series: the two burned-in series are not guaranteed to have the same length,
        # so the inversion must not reuse the count of the other one
        numslices = len(files)
        for file in files:
            ds = dicom.dcmread(os.path.join(input_dir, file))
            if invert_burnedin:
                # mirror the slice order (1 <-> numslices)
                ds.InstanceNumber = numslices - ds.InstanceNumber + 1
            ds.save_as(os.path.join(output_dir, file))

        # convert to NIfTI
        reader = sitk.ImageSeriesReader()
        reader.SetFileNames(reader.GetGDCMSeriesFileNames(output_dir))
        image = reader.Execute()
        sitk.WriteImage(image, os.path.join(data_dir, name, f"ready_to_CaPTk_{suffix}.nii.gz"))


def build_t1ce_t2_flair(patient_info, result_dir):
    """
    Function that builds the T1ce, FLAIR and T2 DICOM series (eventually inverting their reading direction) and
    converts each to NIfTI so they are ready to be fed to CaPTk.
        - T1ce and FLAIR are copied as given; their slice order is inverted when "Inv-base" is set
        - T2 is built from the "FLAIR-whole" series (FLAIR with the manual segmentation burned in); its slice order
          is inverted when "Inv-burnedin" is set
    Outputs are ready_to_CaPTk_T1CE.nii.gz, ready_to_CaPTk_FLAIR.nii.gz and ready_to_CaPTk_T2.nii.gz in the patient
    folder of data_dir. Nothing is returned.
    :param patient_info: pandas series with patient information (its name is the patient folder; the "Outer",
        "T1ce", "FLAIR", "FLAIR-whole", "Inv-base" and "Inv-burnedin" entries are read)
    :param result_dir: directory in which the "t1ce", "flair" and "t2" DICOM sub-folders are written
    """
    name = patient_info.name
    outer = str(patient_info.loc["Outer"])
    invert_base = bool(patient_info.loc["Inv-base"])
    invert_burnedin = bool(patient_info.loc["Inv-burnedin"])
    # (source series folder, output sub-folder, NIfTI suffix, whether to invert the slice order)
    # explicit pairs: the output folder no longer depends on comparing folder names, which misrouted FLAIR
    # into "t1ce" whenever both columns held the same value
    series = [
        (str(patient_info.loc["T1ce"]), "t1ce", "T1CE", invert_base),
        (str(patient_info.loc["FLAIR"]), "flair", "FLAIR", invert_base),
        # T2 is built as FLAIR with manual segmentation
        (str(patient_info.loc["FLAIR-whole"]), "t2", "T2", invert_burnedin),
    ]

    for folder, modality, suffix, invert in series:
        input_dir = os.path.join(data_dir, name, outer, folder)
        output_dir = os.path.join(result_dir, modality)
        os.makedirs(output_dir, exist_ok=True)

        files = os.listdir(input_dir)
        numslices = len(files)
        for file in files:
            ds = dicom.dcmread(os.path.join(input_dir, file))
            if invert:
                # mirror the slice order (1 <-> numslices)
                ds.InstanceNumber = numslices - ds.InstanceNumber + 1
            ds.save_as(os.path.join(output_dir, file))

        # convert to NIfTI
        reader = sitk.ImageSeriesReader()
        reader.SetFileNames(reader.GetGDCMSeriesFileNames(output_dir))
        image = reader.Execute()
        sitk.WriteImage(image, os.path.join(data_dir, name, f"ready_to_CaPTk_{suffix}.nii.gz"))


def ready_for_CaPTk(patient_info, result_dir):
    """
    Function that takes as input the DICOM returned by the hospital .brainlab hardware and builds the NIfTI inputs
    for CaPTk. Specifically:
        - T1ncr and T1core present T1ce with manual segmentation for necrosis and core respectively
        - T1ce is the mri as given
        - T2 consists in FLAIR with the manual segmentation performed by the radiologist
        - FLAIR is the mri as given
    If necessary, mri reading direction are inverted to conform with BraTS2021 axial one.
    The extracted DICOM folder and a top-level DICOMDIR file are deleted afterwards. Nothing is returned.
    :param patient_info: pandas series with patient information (its name is the patient folder; "Outer" is read)
    :param result_dir: directory in which to save the intermediate DICOM series
    """
    name = patient_info.name
    outer = str(patient_info.loc["Outer"])
    patient_dir = os.path.join(data_dir, name)
    series_root = os.path.join(patient_dir, outer)

    # extract all .zip archives; anything else in the patient folder (e.g. BraTS2021_* outputs kept by a previous
    # run, or sub-folders) is skipped instead of making zipfile raise
    extract_to = series_root if outer == "outer" else patient_dir
    for entry in os.listdir(patient_dir):
        archive = os.path.join(patient_dir, entry)
        if not (os.path.isfile(archive) and zipfile.is_zipfile(archive)):
            continue
        with zipfile.ZipFile(archive) as zip_ref:
            zip_ref.extractall(extract_to)

    # harmonize the InstanceNumber (all series start at 1); as in the original pipeline, the 0-based check is only
    # performed for the "outer" layout, where the file name encodes the index in its second-to-last "-" field
    if outer == "outer":
        for folder in os.listdir(series_root):
            folder_dir = os.path.join(series_root, folder)
            if not os.path.isdir(folder_dir):
                continue
            # sorted: os.listdir order is arbitrary, so the inspected file must be chosen deterministically
            images = sorted(os.listdir(folder_dir))
            if not images:
                continue
            if int(images[0].split("-")[-2]) == 0:
                # series starts at 0
                for img in images:
                    img_path = os.path.join(folder_dir, img)
                    ds = dicom.dcmread(img_path)
                    ds.InstanceNumber = ds.InstanceNumber + 1
                    ds.save_as(img_path)

    # build t1
    build_t1ncr_t1core(patient_info=patient_info, result_dir=result_dir)
    # build t1ce, t2 and flair
    build_t1ce_t2_flair(patient_info=patient_info, result_dir=result_dir)

    # delete extracted files
    shutil.rmtree(series_root)
    dicomdir = os.path.join(patient_dir, "DICOMDIR")
    if os.path.isfile(dicomdir):
        os.remove(dicomdir)


def _run_brats_pipeline(patient, t1_name, t1_renamed):
    """
    Run CaPTk BraTSPipeline once for a patient with the given T1 input and rename its T1 output.
    :param patient: space stripped name of patient
    :param t1_name: file name (inside the patient folder) passed to the pipeline as -t1
    :param t1_renamed: new name for T1_to_SRI.nii.gz, so that the next pipeline run does not overwrite it
    :raises FileNotFoundError: if the pipeline did not produce T1_to_SRI.nii.gz
    """
    patient_dir = os.path.join(data_dir, patient)
    status = os.system(
        f"CaPTk_full\\{CaPTk_version}\\bin\\BraTSPipeline.exe "
        f"-t1 {os.path.join(patient_dir, t1_name)} "
        f"-t1c {os.path.join(patient_dir, 'ready_to_CaPTk_T1CE.nii.gz')} "
        f"-t2 {os.path.join(patient_dir, 'ready_to_CaPTk_T2.nii.gz')} "
        f"-fl {os.path.join(patient_dir, 'ready_to_CaPTk_FLAIR.nii.gz')} "
        f"-o {patient_dir} "
        f"-s 0 "  # do not skull strip
        f"-b 0 "  # do not segment brain tumors
        f"-d 0 "  # do not print debugging information
        f"-i 0"  # do not save intermediate files
    )
    t1_output = os.path.join(patient_dir, "T1_to_SRI.nii.gz")
    if not os.path.isfile(t1_output):
        # same exception type os.rename would raise, but with the pipeline exit status for diagnosis
        raise FileNotFoundError(f"BraTSPipeline (exit status {status}) did not produce {t1_output}")
    os.rename(t1_output, os.path.join(patient_dir, t1_renamed))


def CaPTk_preprocessing(patient):
    """
    Function that applies CaPTk preprocessing to all modalities and saves each preprocessed image as
    {modality}_to_SRI.nii.gz, where {modality} can be {FL, T1ncr, T1core, T2, T1CE}. More in detail, the applied
    steps are:
        1. re-orientation to LPS/RAI
        2. image registration to SRI-24 atlas which includes the following steps:
                a. N4 Bias correction (This is a TEMPORARY STEP, and is not applied in the final co-registered output
                   images. It is only used to facilitate optimal registration.)
                b. rigid Registration of T1, T2, FLAIR to T1ce
                c. rigid Registration of T1CE to SRI-24 atlas
                d. applying transformation to the reoriented images
    The pipeline is run twice (once with T1ncr, once with T1core as -t1); the T1 output of each run is renamed so
    that both are kept. Afterwards every sub-folder and every file that is neither *SRI.nii.gz nor *.zip is deleted
    from the patient folder.
    :param patient: space stripped name of patient
    """
    _run_brats_pipeline(patient, "ready_to_CaPTk_T1ncr.nii.gz", "T1ncr_to_SRI.nii.gz")
    _run_brats_pipeline(patient, "ready_to_CaPTk_T1core.nii.gz", "T1core_to_SRI.nii.gz")

    # delete unnecessary files
    patient_dir = os.path.join(data_dir, patient)
    for el in os.listdir(patient_dir):
        el_path = os.path.join(patient_dir, el)
        if os.path.isdir(el_path):
            shutil.rmtree(el_path)
        elif not (el.endswith("SRI.nii.gz") or el.endswith(".zip")):
            # keep only preprocessed and original files
            os.remove(el_path)


def t1ce_SynthStrip(patient, idx, abspath):
    """
    Function that performs SynthStrip skull stripping (through WSL) and saves both output and computed mask.
    T1CE brain stripped output is saved in patient dir as BraTS2021_{10000+idx}_t1ce.nii.gz where idx is patient index.
    Brain strip mask is saved as BraTS2021_{10000+idx}_mask.nii.gz.
    :param patient: space stripped name of patient
    :param idx: patient index
    :param abspath: absolute Windows path whose file name is the T1ce CaPTk preprocessed file and whose directory
        is the one data_dir is relative to (the script passes os.path.abspath("T1CE_to_SRI.nii.gz"))
    """
    import shlex

    input_name = os.path.basename(abspath)
    # patient folder on the Windows side, e.g. C:\project\pre-operative\<patient>; normpath removes a leading "./"
    # from data_dir instead of assuming it is there
    patient_dir = os.path.normpath(os.path.join(os.path.dirname(abspath), data_dir, patient))
    drive, rest = os.path.splitdrive(patient_dir)
    # same layout the original relative "../../mnt" path pointed to: /mnt/<lowercase drive letter>/<rest of path>,
    # written as an absolute path so it does not depend on the depth of the WSL home directory
    wsl_dir = f"/mnt/{drive.rstrip(':').lower()}{rest.replace(os.path.sep, '/')}"

    # shlex.quote leaves ordinary paths untouched and single-quotes those containing spaces or shell metacharacters
    input_path = shlex.quote(f"{wsl_dir}/{input_name}")
    output_path = shlex.quote(f"{wsl_dir}/BraTS2021_{10000 + idx}_t1ce.nii.gz")
    mask_path = shlex.quote(f"{wsl_dir}/BraTS2021_{10000 + idx}_mask.nii.gz")
    os.system(
        f'wsl ~ -e sh -c '
        f'"python3 synthstrip-docker -i {input_path} '
        f'-o {output_path} '
        f'-m {mask_path}"'
    )


def apply_mask_to_t1core_t1ncr_t2_flair(patient, idx):
    """
    Function that applies the SynthStrip mask to the T1core, T1ncr, T2 and FLAIR modalities for skull stripping.
    Voxels where the mask is 0 are set to 0. Brain stripped outputs are saved in the patient folder as
    BraTS2021_{10000+idx}_{modality}.nii.gz, where {modality} can be {t1core, t1ncr, t2, flair}. Afterwards every
    file that is neither *.zip nor BraTS2021* is deleted from the patient folder.
    :param patient: space stripped name of patient
    :param idx: patient index
    """
    patient_dir = os.path.join(data_dir, patient)
    prefix = f"BraTS2021_{10000 + idx}"
    # CaPTk preprocessed file -> BraTS output suffix
    sources = {
        "t1core": "T1core_to_SRI.nii.gz",
        "t1ncr": "T1ncr_to_SRI.nii.gz",
        "t2": "T2_to_SRI.nii.gz",
        "flair": "FL_to_SRI.nii.gz",
    }

    # retrieve NIfTI CaPTk preprocessed files (all loaded before anything is written)
    mask = nib.load(os.path.join(patient_dir, f"{prefix}_mask.nii.gz"))
    images = {suffix: nib.load(os.path.join(patient_dir, source)) for suffix, source in sources.items()}

    # skull strip and save
    not_brain = mask.get_fdata() == 0
    for suffix, image in images.items():
        brain = image.get_fdata()
        brain[not_brain] = 0.
        # np.float (an alias of builtin float, i.e. float64) was removed in NumPy 1.24
        stripped = nib.Nifti1Image(brain.astype(np.float64), image.affine)
        nib.save(stripped, os.path.join(patient_dir, f"{prefix}_{suffix}.nii.gz"))

    # delete unnecessary files
    for file in os.listdir(patient_dir):
        if not (file.endswith(".zip") or file.startswith("BraTS2021")):
            # keep only BraTS-ready and original files
            os.remove(os.path.join(patient_dir, file))


def build_segmentation(patient, idx):
    """
    Function that joins the three labels into a _seg.nii.gz file as conforming with BraTS2021 rules.
    In each skull-stripped burned-in image, the voxels equal to the image maximum are taken as its segmentation:
        - ncr: maximum of T1ncr
        - enhancing: maximum of T1core that is not ncr
        - edema: maximum of T2 (FLAIR with the whole segmentation) that is neither ncr nor enhancing
    The values assigned to each class in the _seg file are the following:
        - ncr (necrosis): 1
        - edema: 2
        - enhancing: 4
    Each binary mask is also saved as BraTS2021_{10000+idx}_{ncr, enh, edema}.nii.gz, and the intermediate t1core,
    t1ncr and t2 files are deleted.
    :param patient: space stripped name of patient
    :param idx: patient index
    """
    patient_dir = os.path.join(data_dir, patient)
    prefix = f"BraTS2021_{10000 + idx}"

    def load_burned_in(suffix):
        # returns the image and the boolean mask of voxels equal to its maximum
        image = nib.load(os.path.join(patient_dir, f"{prefix}_{suffix}.nii.gz"))
        data = image.get_fdata()
        return image, data == np.max(data)

    _, is_ncr = load_burned_in("t1ncr")
    t1core, is_core = load_burned_in("t1core")
    _, is_whole = load_burned_in("t2")

    # boolean set differences (equivalent to the previous float subtraction followed by "> 0")
    is_enh = is_core & ~is_ncr
    is_edema = is_whole & ~is_enh & ~is_ncr

    # join into single array and save new NIfTI
    joined = np.zeros(is_core.shape)
    joined[is_edema] = 2
    joined[is_enh] = 4
    joined[is_ncr] = 1
    seg = nib.Nifti1Image(joined.astype(np.float32), t1core.affine)
    nib.save(seg, os.path.join(patient_dir, f"{prefix}_seg.nii.gz"))

    for label, label_mask in (("ncr", is_ncr), ("enh", is_enh), ("edema", is_edema)):
        label_image = nib.Nifti1Image(label_mask.astype(np.float32), t1core.affine)
        nib.save(label_image, os.path.join(patient_dir, f"{prefix}_{label}.nii.gz"))

    # delete unnecessary files
    for suffix in ("t1core", "t1ncr", "t2"):
        os.remove(os.path.join(patient_dir, f"{prefix}_{suffix}.nii.gz"))


if __name__ == "__main__":

    data_dir = "./pre-operative"
    CaPTk_version = "1.9.0"
    # index given to the first patient; outputs are named BraTS2021_{10000 + idx}_*
    first_idx = 41

    # retrieve required information regarding all patients
    info = pd.read_csv("./info.csv", index_col=0, header=0, comment="#")

    # only patient folders, in sorted order: os.listdir order is arbitrary and idx determines the output names,
    # so the patient -> index mapping must be reproducible across runs and machines
    patients = sorted(
        entry for entry in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, entry))
    )

    for idx, patient in enumerate(patients, first_idx):
        print(info.loc[patient])

        # build the MRI modalities as a first step before CaPTk
        ready_for_CaPTk(patient_info=info.loc[patient], result_dir=os.path.join(data_dir, patient))

        # preprocess the patient for BraTS2021 compatibility (N4 bias correction, SRI24 template co-registration, etc.)
        CaPTk_preprocessing(patient=patient)

        # run skull-strip of T1CE and get mask
        t1ce_SynthStrip(patient=patient, idx=idx, abspath=os.path.abspath("T1CE_to_SRI.nii.gz"))

        # skull strip T1core, T1ncr, T2 and FLAIR
        apply_mask_to_t1core_t1ncr_t2_flair(patient=patient, idx=idx)

        # build segmentation NIfTI
        build_segmentation(patient=patient, idx=idx)