
import numpy as np
import glob
import os

from pytorch_lightning import LightningDataModule
from sklearn.model_selection import KFold
from utils.utils import get_config_file, get_task_code, print0

from data_loading.dali_loader import fetch_dali_loader


# DataModule makes use of the NVIDIA Data Loading Library (DALI)
# read more at: https://docs.nvidia.com/deeplearning/dali/user-guide/docs/


class DataModule(LightningDataModule):
    """
    LightningDataModule that discovers preprocessed ``.npy`` files, splits them
    into cross-validation folds and serves them through DALI data loaders.
    """

    def __init__(self, args):
        """
        Initialize the LightningDataModule.
        A datamodule encapsulates the five steps involved in data processing in PyTorch:
            download / tokenize / process
            clean and (maybe) save to disk
            load inside Dataset
            apply transforms (rotate, tokenize, etc…)
            wrap inside a DataLoader
        :param args: parsed run arguments; this class reads ``nfolds``, ``dim``, ``seed``,
            ``gpus``, ``num_workers``, ``exec_mode`` and ``fold`` from it
        """
        super().__init__()
        self.args = args
        self.data_path = get_data_path(args)
        self.kfold = get_kfold_splitter(args.nfolds)
        # Read the config once instead of once per key.
        config = get_config_file(self.args)
        # Keyword arguments forwarded to every fetch_dali_loader call.
        self.kwargs = {
            "dim": self.args.dim,
            "seed": self.args.seed,
            "gpus": self.args.gpus,
            "num_workers": self.args.num_workers,
            "patch_size": config["patch_size"],
            "in_channels": config["in_channels"],
        }
        # Give each attribute its own list: a single shared list object would make an
        # in-place change to one of them show up in all the others.
        self.train_images, self.train_labels = [], []
        self.val_images, self.val_labels = [], []
        self.test_images = []

    def setup(self, stage=None):
        """
        Discover the data files and build the train / validation / test file lists.
        Unless ``exec_mode`` is "predict", images and labels are split according to the
        fold selected by ``args.fold``, and the original labels and metadata of the
        validation fold are added to the loader keyword arguments. In "predict" mode only
        the test file list and its metadata are set.
        :param stage: accepted for compatibility with the Lightning hook; not used here
        """
        meta = load_data(self.data_path, "*_meta.npy")
        images = load_data(self.data_path, "*_x.npy")
        self.test_images, test_meta = get_test_fnames(self.args, self.data_path, meta)

        if self.args.exec_mode != "predict":
            orig_lbl = load_data(self.data_path, "*_orig_lbl.npy")
            labels = load_data(self.data_path, "*_y.npy")
            train_idx, val_idx = list(self.kfold.split(images))[self.args.fold]
            # Only the validation fold's original labels and metadata are handed to the loaders.
            orig_lbl, meta = get_split(orig_lbl, val_idx), get_split(meta, val_idx)
            self.kwargs.update({"orig_lbl": orig_lbl, "meta": meta})
            self.train_images, self.train_labels = get_split(images, train_idx), get_split(labels, train_idx)
            self.val_images, self.val_labels = get_split(images, val_idx), get_split(labels, val_idx)
        else:
            # prediction only
            self.kwargs.update({"meta": test_meta})
        print0(f"{len(self.train_images)} training, {len(self.val_images)} validation, {len(self.test_images)} test examples")

    def train_dataloader(self):
        """
        Fetch the train DALI data loader.
        :return: train DALI data loader
        """
        # NOTE(review): the positional 1 is presumably the batch size - confirm against fetch_dali_loader.
        return fetch_dali_loader(self.train_images, self.train_labels, 1, "train", **self.kwargs)

    def val_dataloader(self):
        """
        Fetch the eval DALI data loader.
        :return: eval DALI data loader
        """
        return fetch_dali_loader(self.val_images, self.val_labels, 1, "eval", **self.kwargs)

    def test_dataloader(self):
        """
        Fetch the test DALI data loader (no labels are passed).
        :return: test DALI data loader
        """
        return fetch_dali_loader(self.test_images, None, 1, "test", **self.kwargs)


def get_split(data, idx):
    """
    Select the entries of data located at the positions given by idx.
    :param data: sequence to select from (converted to a NumPy array first)
    :param idx: indices of the entries to keep, in the desired order
    :return: list with the selected entries (as produced by NumPy indexing)
    """
    as_array = np.array(data)
    return [*as_array[idx]]


def load_data(path, files_pattern, non_empty=True):
    """
    Retrieve the sorted filenames under a path that match a glob pattern.
    :param path: directory to search
    :param files_pattern: glob pattern the filenames must match (e.g. "*_x.npy")
    :param non_empty: if True, fail when no file matches
    :return: sorted list of matching file paths (possibly empty when non_empty is False)
    :raises AssertionError: if non_empty is True and no file matches
    """
    data = sorted(glob.glob(os.path.join(path, files_pattern)))
    # Raise explicitly instead of using an `assert` statement: asserts are removed when
    # Python runs with -O, which would let an empty result slip through silently.
    # AssertionError is kept so existing callers see the same exception type and message.
    if non_empty and not data:
        raise AssertionError(f"No data found in {path} with pattern {files_pattern}")

    return data


def get_kfold_splitter(nfolds):
    """
    Build the sklearn.model_selection.KFold object used to split the data into folds.
    :param nfolds: number of folds to split into
    :return: shuffling KFold splitter configured with a fixed random_state
    """
    # random_state is a constant here, independent of any run-level seed.
    splitter = KFold(n_splits=nfolds, shuffle=True, random_state=12345)
    return splitter


def get_test_fnames(args, data_path, meta=None):
    """
    Collect the test image filenames ("*_x.npy") found in data_path.
    :param args: args (currently not used by this function)
    :param data_path: directory to search for test images
    :param meta: metadata, handed back to the caller unchanged
    :return: tuple (list of test image filenames, meta); the list may be empty
    """
    # non_empty=False: having no test images is not treated as an error here.
    return load_data(data_path, "*_x.npy", non_empty=False), meta


def get_data_path(args):
    """
    Resolve the directory the data is read from.
    A user-supplied location (anything other than "./data") is returned as is. For the
    default location the task code is appended, followed by "test" in "predict" mode.
    :param args: args; reads args.data and, for the default location, args.exec_mode
    :return: data path
    """
    if args.data == "./data":
        components = [args.data, get_task_code(args)]
        if args.exec_mode == "predict":
            components.append("test")
        return os.path.join(*components)

    # Custom location: use it verbatim.
    return args.data
