
import json
import math
import os
import pickle
import shutil
from subprocess import run

import monai.transforms as transforms
import nibabel
import numpy as np
import torch
from joblib import Parallel, delayed
from monai.utils import InterpolateMode


# inspired by the NVIDIA nnU-Net GitHub repository available at:
# https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Segmentation/nnUNet

# preprocessing makes use of the MONAI toolkit available at:
# https://github.com/Project-MONAI/MONAI


# Dataset subdirectory (relative to args.data) for each supported task name (args.task).
task = dict(
    train="BraTS2021_train",
    val="BraTS2021_val",
)


class Preprocessor:
    """
    Turn NIfTI cases listed in ``dataset.json`` into cropped, resized, intensity-scaled numpy arrays.

    For every entry of ``metadata[exec_mode]`` the following files are written to ``self.results``:
    ``<case>_x.npy`` (available modalities), ``<case>_meta.npy`` (crop bookkeeping) and, in training
    mode, ``<case>_y.npy`` (processed target) and ``<case>_orig_lbl.npy`` (uncropped target).
    """

    def __init__(self, args):
        """
        Initialize the preprocessor.
        :param args: namespace providing ``task``, ``dim``, ``verbose``, ``exec_mode``, ``data``,
            ``results`` and ``n_jobs``
        """
        self.args = args
        # only reported (verbose print, config.pkl); volumes are not resampled in this class
        self.target_spacing = [1.0, 1.0, 1.0]
        # get task (here either "train" or "val" for BraTS21 training and validating respectively)
        self.task = args.task
        # get task code
        self.task_code = f"{args.task}_{args.dim}d"
        self.verbose = args.verbose
        # in-plane size every slice is resized to
        self.patch_size = [224, 224]
        # determine if training (targets are only loaded in training mode)
        self.training = args.exec_mode == "training"
        # path of input data
        self.data_path = os.path.join(args.data, task[args.task])
        # retrieve metadata from previously created json
        metadata_path = os.path.join(self.data_path, "dataset.json")
        with open(metadata_path, "r") as metadata_file:
            self.metadata = json.load(metadata_file)
        # retrieve modalities used
        self.modality = self.metadata["modality"]["0"]
        self.results = os.path.join(args.results, self.task_code)
        # not read anywhere in this class
        self.ct_min, self.ct_max, self.ct_mean, self.ct_std = (0,) * 4
        if not self.training:
            self.results = os.path.join(self.results, self.args.exec_mode)
        # z-score normalize each channel over the whole volume (nonzero=False: background included)
        self.normalize_intensity = transforms.NormalizeIntensity(nonzero=False, channel_wise=True)
        # min-max scale each channel to [0, 1]
        self.scale_intensity = transforms.ScaleIntensity(minv=0.0, maxv=1.0, channel_wise=True)
        # bicubic resize of the spatial dims to match patch size
        self.resize_modality = transforms.Resize(spatial_size=self.patch_size, mode=InterpolateMode.BICUBIC,
                                                 align_corners=True)

    def run(self):
        """
        Apply the preprocessing step to every case, then write ``config.pkl``.

        Any previous content at ``self.results`` is deleted first.
        """
        # start from an empty results directory (same effect as ``rm -rf`` without a subprocess)
        if os.path.isdir(self.results) and not os.path.islink(self.results):
            shutil.rmtree(self.results)
        elif os.path.lexists(self.results):
            os.remove(self.results)
        os.makedirs(self.results)
        print(f"Preprocessing {self.data_path}")
        if self.verbose:
            print(f"Target spacing {self.target_spacing}")

        self.run_parallel(self.preprocess_pair, self.args.exec_mode)
        # create pickle with infos
        config = {
            "patch_size": self.patch_size,
            "spacings": self.target_spacing,
            "in_channels": len(self.metadata["modality"]),
        }
        with open(os.path.join(self.results, "config.pkl"), "wb") as config_file:
            pickle.dump(config, config_file)

    def preprocess_pair(self, pair):
        """
        Preprocess a pair (available, target) and save the result.

        Steps: foreground crop -> pad every slice to a square -> resize to ``self.patch_size`` ->
        abs + z-score normalization -> min-max scaling to [0, 1].
        :param pair: dict with an "available" path (plus "target" when training), or a single path
        """
        fname = os.path.basename(pair["available"] if isinstance(pair, dict) else pair)
        available, target, _ = self.load_pair(pair)

        # crop to the foreground bounding box and keep what is needed to undo it
        orig_shape = available.shape[1:]
        bbox = transforms.utils.generate_spatial_bounding_box(available)
        # NOTE(review): the box spans every spatial axis, so slices outside it are cropped as well
        crop = transforms.SpatialCrop(roi_start=bbox[0], roi_end=bbox[1])
        available = crop(available)
        # rows: bbox start, bbox end, original spatial shape, cropped spatial shape
        available_metadata = np.vstack([bbox, orig_shape, available.shape[1:]])
        if target is not None:
            self.save_npy(target, fname, "_orig_lbl.npy")
            target = crop(target)

        # per channel, pad the last two axes to a square of side max(H, W), then resize to patch_size
        side = int(np.max(available.shape[2:]))
        pad = transforms.ResizeWithPadOrCrop([side, side])
        available = np.stack([self.resize(pad(available[idx])) for idx in range(available.shape[0])])
        if target is not None:
            target = np.expand_dims(self.resize(pad(target[0])), 0)

        # presumably abs() folds back the negative overshoot of bicubic resizing before normalizing
        available = self.normalize(np.abs(available))
        if target is not None:
            target = self.normalize(np.abs(target))

        # scale intensities to [0, 1]
        available = self.scale(available, nonzero=False)
        if target is not None:
            target = self.scale(target, nonzero=False)

        self.save(available, target, fname, available_metadata)

    def normalize(self, modality):
        """
        Z-score normalize the intensity of a given modality, channel-wise, over the whole volume.
        :param modality: modality
        :return: intensity-normalized modality
        """
        return self.normalize_intensity(modality)

    def scale(self, modality, nonzero=True):
        """
        Min-max scale the intensity of a given modality to the range [0, 1], channel-wise.
        :param modality: modality; indexed as (C, D, H, W) when ``nonzero`` is True
        :param nonzero: after scaling, set to 0 every voxel of a channel that equals that channel's
            value at index (0, 0, 0), i.e. treat the corner value as background
        :return: intensity-scaled modality
        """
        scaled = self.scale_intensity(modality)
        if nonzero:
            # reset background at 0
            for idx in range(scaled.shape[0]):
                scaled[idx, scaled[idx] == scaled[idx, 0, 0, 0]] = 0.
        return scaled

    def resize(self, modality):
        """
        Resize the spatial shape of the input modality to match self.patch_size (bicubic).
        :param modality: modality; its first axis is treated as the channel axis
        :return: Resized modality
        """
        return self.resize_modality(modality)

    def save(self, available, target, fname, available_metadata):
        """
        Save available and target modalities with respective metadata.
        :param available: available modality
        :param target: target modality, or None (not saved)
        :param fname: NIfTI file name the output names are derived from
        :param available_metadata: available modality metadata, or None (not saved)
        """
        if self.verbose:
            # statistics are only needed for this log line; np.asarray also accepts tensor outputs
            stats = np.asarray(available)
            mean = np.round(np.mean(stats, (1, 2, 3)), 2)
            std = np.round(np.std(stats, (1, 2, 3)), 2)
            print(f"Saving {fname} shape {available.shape} mean {mean} std {std}")
        # save scans as numpy files
        self.save_npy(available, fname, "_x.npy")
        if target is not None:
            self.save_npy(target, fname, "_y.npy")
        if available_metadata is not None:
            self.save_npy(available_metadata, fname, "_meta.npy")

    def load_pair(self, pair):
        """
        Load available, target and spacings modalities from previously saved NIfTI.
        :param pair: dict with "available" (and "target" when training) paths, or a single path
        :return: available, target (None when not training), available_spacing (reversed pixdim)
        """
        available = self.load_nifti(pair["available"] if isinstance(pair, dict) else pair)
        # load spacing, reversed to follow the transposed axis order below
        available_spacing = available.header["pixdim"][1:4].tolist()[::-1]
        available = available.get_fdata().astype(np.float32)
        # standardize layout in (C, 155, 240, 240)
        available = np.transpose(available, (3, 2, 1, 0))

        if self.training:
            target = self.load_nifti(pair["target"]).get_fdata().astype(np.float32)
            # standardize layout in (1, 155, 240, 240)
            target = np.expand_dims(target, 3)
            target = np.transpose(target, (3, 2, 1, 0))
        else:
            target = None

        return available, target, available_spacing

    def save_npy(self, modality, fname, suffix):
        """
        Save the result modality as numpy file ``<stem><suffix>`` inside ``self.results``.
        :param modality: modality
        :param fname: file name; a trailing ".nii.gz" (otherwise its last extension) is replaced
        :param suffix: file suffix, e.g. "_x.npy"
        """
        if fname.endswith(".nii.gz"):
            stem = fname[:-len(".nii.gz")]
        else:
            # without this, names lacking ".nii.gz" would get no suffix and all outputs would collide
            stem = os.path.splitext(fname)[0]
        np.save(os.path.join(self.results, stem + suffix), modality, allow_pickle=False)

    def run_parallel(self, func, exec_mode):
        """
        Run parallelized jobs, one per entry of ``metadata[exec_mode]``.
        :param func: function
        :param exec_mode: execution mode
        :return: joblib.Parallel() call
        """
        return Parallel(n_jobs=self.args.n_jobs)(delayed(func)(pair) for pair in self.metadata[exec_mode])

    def load_nifti(self, fname):
        """
        Load a NIfTI file relative to ``self.data_path``.
        :param fname: filename
        :return: nibabel.load() call
        """
        return nibabel.load(os.path.join(self.data_path, fname))
