
import numpy as np
import os
import nvidia.dali.fn as fn
import nvidia.dali.ops as ops
import nvidia.dali.types as types

from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator


# data loading and augmentation make use of the NVIDIA Data Loading Library (DALI)
# read more at: https://docs.nvidia.com/deeplearning/dali/user-guide/docs/


def random_augmentation(probability, augmented, original):
    """
    Choose between an augmented and an original input using a random coin flip.

    The coin flip is cast to a boolean mask; the result is the augmented input
    where the mask is set and the original input where it is not.
    :param probability: probability of selecting the augmented input
    :param augmented: augmented image
    :param original: original image
    :return: either ``augmented`` or ``original``, selected by the coin flip
    """
    coin = fn.random.coin_flip(probability=probability)
    use_augmented = fn.cast(coin, dtype=types.DALIDataType.BOOL)
    # XOR with True inverts the boolean mask
    use_original = use_augmented ^ True

    picked_augmented = use_augmented * augmented
    picked_original = use_original * original
    return picked_augmented + picked_original


class GenericPipeline(Pipeline):
    """
    Base DALI pipeline shared by the train/eval/test pipelines.

    Creates numpy file readers for the images and, when label paths are given,
    for the labels, and provides the common loading and axis-transposition steps.
    """

    def __init__(self, batch_size, num_threads, device_id, **kwargs):
        """
        Initialize the generic data loading pipeline.
        :param batch_size: batch size
        :param num_threads: number of threads
        :param device_id: device id; also used as the reader shard id
        :param kwargs: pipeline options. Keys read by this class: "dim", "in_channels",
            "patch_size", "load_to_gpu", "imgs", "seed", "gpus", "shuffle", and the
            optional "lbls" (list of label paths, or None / absent when there are no labels)
        """
        super().__init__(batch_size, num_threads, device_id)
        self.kwargs = kwargs
        self.dim = kwargs["dim"]
        self.device = device_id
        # Prepend the channel count to the spatial patch size. list() makes this work for
        # tuples (where `list + tuple` raises TypeError) and numpy arrays (where
        # `list + ndarray` would broadcast-add instead of concatenating).
        self.patch_size = [kwargs["in_channels"]] + list(kwargs["patch_size"])
        self.load_to_gpu = kwargs["load_to_gpu"]
        self.input_x = self.get_reader(kwargs["imgs"])
        # labels are optional: no label reader is created when they are missing
        label_files = kwargs.get("lbls")
        self.input_y = self.get_reader(label_files) if label_files is not None else None

    def get_reader(self, data):
        """
        Retrieve the reader.
        :param data: list with data paths
        :return: nvidia.dali.ops.readers.Numpy()
        """
        # the file list is split into `gpus` shards; this pipeline reads shard `device_id`
        return ops.readers.Numpy(
            files=data,
            device="cpu",
            read_ahead=True,
            dont_use_mmap=True,
            pad_last_batch=True,
            shard_id=self.device,
            seed=self.kwargs["seed"],
            num_shards=self.kwargs["gpus"],
            shuffle_after_epoch=self.kwargs["shuffle"],
        )

    def load_data(self):
        """
        Load data (eventually to gpu).
        :return: pair (image, label) if labels are present, otherwise image only.
        """
        image = self.input_x(name="ReaderX")
        if self.load_to_gpu:
            image = image.gpu()
        image = fn.reshape(image, layout="CDHW")  # layout standardization, if needed

        if self.input_y is not None:
            # a reader was set for labels as well
            label = self.input_y(name="ReaderY")
            if self.load_to_gpu:
                label = label.gpu()
            label = fn.reshape(label, layout="CDHW")

            return image, label

        return image

    def transpose_fn(self, image, label):
        """
        Swap the first two axes of both elements of the pair (image, label).
        The layout string is not permuted (transpose_layout=False).
        :param image: image
        :param label: label
        :return: transposed pair
        """
        image = fn.transpose(image, perm=(1, 0, 2, 3), transpose_layout=False)
        label = fn.transpose(label, perm=(1, 0, 2, 3), transpose_layout=False)

        return image, label


class TrainPipeline(GenericPipeline):
    """
    Training pipeline: loads (image, label) pairs, randomly flips them and
    swaps their first two axes.
    """

    def __init__(self, batch_size, num_threads, device_id, **kwargs):
        """
        Initialize the training-specific data loading pipeline.
        :param batch_size: batch size
        :param num_threads: number of threads
        :param device_id: device id
        :param kwargs: pipeline options, forwarded unchanged to GenericPipeline
        """
        super().__init__(batch_size, num_threads, device_id, **kwargs)
        # constant nodes holding the patch shape; this class's define_graph does not use them
        full_patch = np.array(self.patch_size)
        single_channel_patch = np.array([1] + self.patch_size[1:])
        self.crop_shape = types.Constant(full_patch, dtype=types.INT64)
        self.crop_shape_float_available = types.Constant(full_patch, dtype=types.FLOAT)
        self.crop_shape_float_target = types.Constant(single_channel_patch, dtype=types.FLOAT)

    @staticmethod
    def slice_fn(image):
        """
        Slice the given image along axis 0.
        The anchor (1) and shape (3) are passed as positional inputs to fn.slice;
        NOTE(review): with integer inputs this presumably selects indices 1..3 — confirm.
        :param image: image
        :return: sliced image
        """
        return fn.slice(image, 1, 3, axes=[0])

    def resize(self, data, crop_shape, interp_type):
        """
        Resize the given data.
        :param data: data
        :param crop_shape: target size passed to fn.resize
        :param interp_type: interpolation method
        :return: resized data
        """
        return fn.resize(data, size=crop_shape, interp_type=interp_type)

    def flips_fn(self, image, label):
        """
        Flip input pair volume, independently for each axis, with probability of 0.15.
        The same random flags are applied to image and label so they stay aligned.
        :param image: image
        :param label: label
        :return: eventually flipped pair (image, label)
        """
        flip_flags = {
            axis: fn.random.coin_flip(probability=0.15)
            for axis in ("horizontal", "vertical", "depthwise")
        }
        flipped_image = fn.flip(image, **flip_flags)
        flipped_label = fn.flip(label, **flip_flags)
        return flipped_image, flipped_label

    def define_graph(self):
        """
        Define the whole data loading pipeline: load -> random flips -> transpose.
        :return: end-of-pipeline pair (image, label)
        """
        pair = self.load_data()
        pair = self.flips_fn(*pair)
        return self.transpose_fn(*pair)


class EvalPipeline(GenericPipeline):
    """Evaluation pipeline: loads (image, label) pairs and swaps their first two axes, without augmentation."""

    def __init__(self, batch_size, num_threads, device_id, **kwargs):
        """
        Initialize the evaluation-specific data loading pipeline.
        :param batch_size: batch size
        :param num_threads: number of threads
        :param device_id: device id
        :param kwargs: pipeline options, forwarded unchanged to GenericPipeline
        """
        super().__init__(batch_size, num_threads, device_id, **kwargs)

    def define_graph(self):
        """
        Define the whole data loading pipeline: load -> transpose.
        :return: retrieved pair (image, label)
        """
        loaded_pair = self.load_data()
        return self.transpose_fn(*loaded_pair)


class TestPipeline(GenericPipeline):
    """Test pipeline: loads images (no labels) together with their metadata files."""

    def __init__(self, batch_size, num_threads, device_id, **kwargs):
        """
        Initialize the testing-specific data loading pipeline.
        :param batch_size: batch size
        :param num_threads: number of threads
        :param device_id: device id
        :param kwargs: pipeline options, forwarded to GenericPipeline; must also
            contain "meta" (list of metadata file paths)
        """
        super().__init__(batch_size, num_threads, device_id, **kwargs)
        # no labels in test mode: a second reader provides the per-image metadata instead
        self.input_meta = self.get_reader(kwargs["meta"])

    def define_graph(self):
        """
        Define the whole data loading pipeline: load image -> transpose, plus metadata.
        :return: retrieved pair (image, metadata)
        """
        transposed = fn.transpose(self.load_data(), perm=(1, 0, 2, 3), transpose_layout=False)
        metadata = self.input_meta(name="ReaderM")
        return transposed, metadata


# mode name -> pipeline class used by fetch_dali_loader
PIPELINES = dict(
    train=TrainPipeline,
    eval=EvalPipeline,
    test=TestPipeline,
)


class LightningWrapper(DALIGenericIterator):
    """
    Thin DALIGenericIterator subclass that yields the output of the first
    (and, in this module, only) pipeline instead of a per-pipeline list.
    """

    def __init__(self, pipe, **kwargs):
        """
        Initialize the DALI iterator for PyTorch.
        :param pipe: pipeline (or list of pipelines) to iterate over
        :param kwargs: keyword arguments forwarded to DALIGenericIterator
            (e.g. output_map, reader_name, auto_reset, dynamic_shape)
        """
        super().__init__(pipe, **kwargs)

    def __next__(self):
        """
        Retrieve the next batch.
        :return: the first pipeline's entry of DALIGenericIterator's result
            (per DALI docs, a dict keyed by the output_map names)
        """
        per_pipeline_outputs = super().__next__()
        return per_pipeline_outputs[0]


def fetch_dali_loader(images, labels, batch_size, mode, **kwargs):
    """
    Fetch the DALI iterator loaded with the pipeline selected by ``mode``.
    :param images: list of image file paths (must be non-empty)
    :param labels: list of label file paths, same length as ``images``, or None
    :param batch_size: batch size
    :param mode: desired pipeline. Choose one between ["train", "eval", "test"]
    :param kwargs: pipeline options forwarded to the pipeline constructor. Must contain
        "num_workers" (used as the DALI thread count) plus the keys the pipelines read
        (e.g. "dim", "in_channels", "patch_size", "seed", "gpus"; "meta" in test mode).
        Entries here take precedence over the computed "imgs"/"lbls"/"load_to_gpu"/"shuffle".
    :return: loaded DALI iterator (LightningWrapper)
    :raises ValueError: if ``images`` is empty, ``labels`` has a different length than
        ``images``, or ``mode`` is not a known pipeline
    """
    # explicit raises instead of assert: asserts are stripped under `python -O`
    # len() (not truthiness) so numpy arrays of paths are accepted as well
    if len(images) == 0:
        raise ValueError("Empty list of images!")
    if labels is not None and len(images) != len(labels):
        raise ValueError(f"Number of images ({len(images)}) "
                         f"not matching number of labels ({len(labels)})")
    if mode not in PIPELINES:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {list(PIPELINES)}")

    pipeline = PIPELINES[mode]  # retrieve the desired pipeline
    shuffle = mode == "train"
    dynamic_shape = mode in ("eval", "test")
    # "benchmark" is not a key of PIPELINES, so that entry is currently unreachable
    load_to_gpu = mode in ("eval", "test", "benchmark")
    pipe_kwargs = {"imgs": images, "lbls": labels, "load_to_gpu": load_to_gpu, "shuffle": shuffle, **kwargs}
    output_map = ["available", "meta"] if mode == "test" else ["available", "target"]
    rank = int(os.getenv("LOCAL_RANK", "0"))  # device id, also used as the reader shard id
    pipe = pipeline(batch_size, kwargs["num_workers"], rank, **pipe_kwargs)

    # NOTE(review): DALIGenericIterator's default last_batch_policy is FILL; together with
    # the readers' pad_last_batch=True the final batch may contain repeated samples — confirm
    # this is acceptable for eval/test metrics.
    return LightningWrapper(
        pipe,
        auto_reset=True,
        reader_name="ReaderX",
        output_map=output_map,
        dynamic_shape=dynamic_shape,
    )
