import os
import torch
import numpy as np
import pytorch_lightning as pl
import monai.transforms as transforms

from apex.optimizers import FusedAdam
from pytorch_lightning.utilities import rank_zero_only
from torchmetrics import PeakSignalNoiseRatio, MeanAbsoluteError
from monai.networks.nets import BasicUNet
from monai.networks.nets.basic_unet import TwoConv, Down, UpCat
from monai.networks.blocks.convolutions import Convolution
from monai.optimizers.lr_scheduler import WarmupCosineSchedule
from monai.utils import InterpolateMode

from data_loading.data_module import get_data_path, get_test_fnames
from generator.metrics import SSIM
from utils.utils import get_config_file
from utils.logger import DLLogger


def weight_init(layer):
    """
    Re-initialise convolution weights in place with Kaiming (He) uniform sampling.

    Meant to be passed to ``torch.nn.Module.apply``. Only ``weight`` tensors are
    re-sampled; biases keep their current values. Modules that are not one of the
    handled convolution types are left untouched.

    Because ``Module.apply`` also visits the nested ``Conv2d`` modules, the
    composite-block branches below re-sample weights the ``Conv2d`` branch has
    already set. The resulting distribution is the same either way.
    :param layer: model layer
    """
    if isinstance(layer, torch.nn.Conv2d):
        convs = (layer,)
    elif isinstance(layer, Convolution):
        convs = (layer.conv,)
    elif isinstance(layer, TwoConv):
        convs = (layer.conv_0.conv, layer.conv_1.conv)
    elif isinstance(layer, (Down, UpCat)):
        convs = (layer.convs.conv_0.conv, layer.convs.conv_1.conv)
    else:
        convs = ()

    # same call order as visiting the attributes one by one, so the RNG stream is unchanged
    for conv in convs:
        torch.nn.init.kaiming_uniform_(conv.weight, nonlinearity="relu")


class UGen(pl.LightningModule):
    """
    Lightning module wrapping a 2D MONAI ``BasicUNet`` generator.

    The generator maps the ``"available"`` entry of a batch to a single-channel
    prediction of the ``"target"`` entry and is trained with a mean-squared-error
    loss. Validation tracks MAE, PSNR and SSIM together with the best value seen
    for each metric.
    """

    def __init__(self, args):
        super(UGen, self).__init__()
        self.save_hyperparameters()
        self.args = args
        self.learning_rate = self.args.learning_rate
        self.test_idx = 0  # index of the next test file name used by save_mask
        self.train_loss = []  # per-step training losses, cleared at batch_idx == 0
        self.model = self.build_generator()
        self.model.apply(weight_init)  # initialize convolutions weights
        self.loss = torch.nn.MSELoss(reduction="mean")
        self.patch_size = get_config_file(self.args)["patch_size"]
        self.mae = MeanAbsoluteError()
        self.best_temp_mae = float("inf")
        self.psnr = PeakSignalNoiseRatio(reduction='elementwise_mean')
        self.best_temp_psnr = 0.
        self.ssim = SSIM(spatial_dims=3)
        self.best_temp_ssim = 0.
        # epoch of the best PSNR; initialised so logging never hits a missing attribute
        self.best_epoch = 0
        if self.args.exec_mode == "train":
            self.dllogger = DLLogger(args.results, f"fold{args.fold}_{args.logname}")

    def build_generator(self):
        """
        Build the 2D MONAI BasicUNet generator.

        The network has one output channel, ReLU activations, no normalisation
        layers and non-trainable upsampling. Afterwards each block's first
        convolution gets its own dropout probability, and every upsampling layer
        is switched to "nearest" interpolation.
        :return: MONAI BasicUNet
        """
        unet = BasicUNet(
            spatial_dims=2,
            in_channels=get_config_file(self.args)["in_channels"],
            out_channels=1,
            features=self.args.filters,
            act="ReLU",
            norm=None,
            dropout=0.,
            upsample="nontrainable",
        )
        # NOTE(review): BasicUNet.forward appears to call only its predefined layers
        # (conv_0 ... final_conv), so a module registered via add_module is probably
        # never applied to the output. Verify whether a final tanh was intended
        # before relying on predictions being bounded to [-1, 1].
        unet.add_module("Tanh", torch.nn.Tanh())

        # add dropout where necessary and set Upsample with "nearest" interpolation;
        # one probability per TwoConv / Down / UpCat child, in registration order
        dropouts = [0.1, 0.1, 0.2, 0.2, 0.3, 0.2, 0.2, 0.1, 0.1]
        idx = 0
        for child in unet.children():
            if isinstance(child, TwoConv):
                child.conv_0.adn.D.p = dropouts[idx]
                idx += 1
            elif isinstance(child, Down):
                child.convs.conv_0.adn.D.p = dropouts[idx]
                idx += 1
            elif isinstance(child, UpCat):
                child.upsample.upsample_non_trainable.mode = "nearest"
                child.upsample.upsample_non_trainable.align_corners = None
                child.convs.conv_0.adn.D.p = dropouts[idx]
                idx += 1

        return unet

    def training_step(self, batch, batch_idx):
        """
        Perform the training step once batch is received.

        Only the first element of ``batch["available"]`` / ``batch["target"]`` is
        used. From it, 64 indices along axis 0 are drawn with replacement, and the
        generator is trained on that random subset.
        :param batch: batch of (inputs, outputs)
        :param batch_idx: index of the batch
        :return: computed loss
        """
        if batch_idx == 0:
            self.train_loss = []
        available = batch["available"][0]
        target = batch["target"][0]
        # random extraction (with replacement)
        rand_idx = np.random.randint(available.shape[0], size=64)
        available = available[rand_idx]
        target = target[rand_idx]
        generated = self.model(available)
        loss = self.loss(generated, target)
        self.train_loss.append(loss.item())

        return loss

    def validation_step(self, batch, batch_idx):
        """
        Perform the validation step once batch is received.

        Metric states are only accumulated here; they are computed and reset in
        ``validation_epoch_end``.
        :param batch: batch of (inputs, outputs)
        :param batch_idx: index of the batch
        """
        available = batch["available"][0]
        target = batch["target"][0]
        generated = self.model(available)
        loss = self.loss(generated, target)
        # update metrics (the SSIM metric also accumulates the MSE loss)
        self.mae.update(generated, target)
        self.psnr.update(generated, target)
        self.ssim.update(generated, target, loss)

    def test_step(self, batch, batch_idx):
        """
        Run the generator on a test batch and, if ``args.save_preds`` is set, save
        the prediction resized and placed back into the original volume shape.

        ``batch["meta"][0]`` is read as: row 0 = crop start (d, h, w), row 1 = crop
        end (d, h, w), row 2 = original volume shape.
        :param batch: batch
        :param batch_idx: batch index
        """
        available = batch["available"][0]
        # assumes model output is (N, 1, H, W): move the channel axis first and drop
        # it, leaving N stacked 2D predictions
        prediction = torch.transpose(self.model(available), 0, 1).squeeze(0).detach().cpu().numpy()
        if not self.args.save_preds:
            return

        # resize to original shape and save as .npy file
        meta = batch["meta"][0].detach().cpu().numpy()
        min_d, max_d = meta[0, 0], meta[1, 0]
        min_h, max_h = meta[0, 1], meta[1, 1]
        min_w, max_w = meta[0, 2], meta[1, 2]
        # resize to a square of the larger crop side, then pad/crop to the exact
        # crop size (MONAI transforms presumably treat axis 0 as channels here)
        patch_dim = max(max_h - min_h, max_w - min_w)
        prediction = transforms.Resize([patch_dim, patch_dim], mode=InterpolateMode.BICUBIC,
                                       align_corners=True, anti_aliasing=True)(prediction)
        prediction = transforms.ResizeWithPadOrCrop([max_h - min_h, max_w - min_w])(prediction)
        original_shape = meta[2]
        final_pred = np.zeros(original_shape)
        final_pred[min_d:max_d, min_h:max_h, min_w:max_w] = prediction

        self.save_mask(final_pred)

    def _mean_train_loss(self):
        """Return the mean recorded training loss, or NaN when none was recorded."""
        if not self.train_loss:
            return float("nan")
        return sum(self.train_loss) / len(self.train_loss)

    def validation_epoch_end(self, outputs):
        """
        Compute and reset the validation metrics, update the best values and log them.

        The best MAE is the minimum seen; the best PSNR/SSIM are the maxima.
        ``best_epoch`` follows PSNR. The Lightning sanity-check run is not counted:
        it happens before any training step.
        :param outputs: outputs (unused)
        """
        # compute metrics and clear their accumulated state
        mae = self.mae.compute()
        self.mae.reset()
        psnr = self.psnr.compute()
        self.psnr.reset()
        ssim, loss = self.ssim.compute()
        self.floss = loss  # kept as an attribute; not read elsewhere in this class
        self.ssim.reset()

        # an untrained model must neither set "best" metrics nor average an
        # empty train_loss list
        if getattr(self.trainer, "sanity_checking", False):
            return

        mae_mean = torch.mean(mae)
        psnr_mean = torch.mean(psnr)
        ssim_mean = torch.mean(ssim)
        # update best metrics
        if mae_mean <= self.best_temp_mae:
            self.best_temp_mae = mae_mean
        if psnr_mean >= self.best_temp_psnr:
            self.best_temp_psnr = psnr_mean
            self.best_epoch = self.current_epoch
        if ssim_mean >= self.best_temp_ssim:
            self.best_temp_ssim = ssim_mean

        # float() accepts both the initial Python floats and 0-dim tensors
        metrics = {
            "MAE": round(mae_mean.item(), 4),
            "Min MAE": round(float(self.best_temp_mae), 4),
            "PSNR": round(psnr_mean.item(), 2),
            "Max PSNR": round(float(self.best_temp_psnr), 2),
            "SSIM": round(ssim_mean.item(), 2),
            "Max SSIM": round(float(self.best_temp_ssim), 2),
            "Best epoch": self.best_epoch,
            "Train Loss": round(self._mean_train_loss(), 4),
            "Val Loss": round(torch.mean(loss).item(), 4),
        }

        # the DLLogger only exists when exec_mode == "train"
        dllogger = getattr(self, "dllogger", None)
        if dllogger is not None:
            dllogger.log_metrics(step=self.current_epoch, metrics=metrics)
            dllogger.flush()
        if self.args.tb_logs:
            self.logger.log_metrics(metrics, step=self.current_epoch)
        self.log("mae", metrics["MAE"])
        self.log("psnr", metrics["PSNR"])
        self.log("ssim", metrics["SSIM"])

    @rank_zero_only
    def on_fit_end(self):
        """
        Define the fit-end step: log the best metrics and the last epoch's mean
        training loss, then flush. Runs on rank zero only.
        """
        dllogger = getattr(self, "dllogger", None)
        if dllogger is None:
            return
        metrics = {
            "mae_score": round(float(self.best_temp_mae), 4),
            "psnr_score": round(float(self.best_temp_psnr), 2),
            "ssim_score": round(float(self.best_temp_ssim), 2),
            "train_loss": round(self._mean_train_loss(), 4),
            "Epoch": self.best_epoch,
        }

        dllogger.log_metrics(step=(), metrics=metrics)
        dllogger.flush()

    def configure_optimizers(self):
        """
        Configure the FusedAdam optimizer (weight decay 1e-6).

        When ``args.scheduler`` is set, a warmup-cosine schedule is added. It
        steps every optimizer step, with 250 warmup steps, over
        ``epochs * len(train_dataloader)`` total steps.
        """
        optimizer = FusedAdam(self.parameters(), lr=self.learning_rate, weight_decay=1e-6)

        if self.args.scheduler:
            # apply warmup cosine scheduler
            scheduler = {
                "scheduler": WarmupCosineSchedule(
                    optimizer=optimizer,
                    warmup_steps=250,
                    t_total=self.args.epochs * len(self.trainer.datamodule.train_dataloader()),
                ),
                "interval": "step",
                "frequency": 1,
            }
            return {"optimizer": optimizer, "monitor": "val_loss", "lr_scheduler": scheduler}

        return {"optimizer": optimizer, "monitor": "val_loss"}

    # workaround to avoid warning regarding learning rate and optimizer updates <- due to AMP
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **kwargs):
        """
        Step the optimizer and record whether the next LR-scheduler step should be
        skipped.

        If the precision plugin exposes a gradient ``scaler`` and its scale
        decreased across the step, the step is treated as skipped by AMP.
        """
        self.should_skip_lr_scheduler_step = False
        scaler = getattr(self.trainer.strategy.precision_plugin, "scaler", None)
        scale_before_step = scaler.get_scale() if scaler is not None else None
        optimizer.step(closure=optimizer_closure)
        if scaler is not None:
            self.should_skip_lr_scheduler_step = scale_before_step > scaler.get_scale()

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        """Step the scheduler unless ``optimizer_step`` flagged the last step as skipped."""
        if self.should_skip_lr_scheduler_step:
            return
        scheduler.step()

    def save_mask(self, prediction):
        """
        Save the prediction after testing as a ``.npy`` array.

        The file name is the basename of the current test image with every
        ``"_x"`` removed. The test file list is loaded on the first call.
        NOTE(review): ``self.save_dir`` is not set in this class; it is presumably
        assigned by the caller before testing.
        :param prediction: prediction
        """
        if self.test_idx == 0:
            data_path = get_data_path(self.args)
            self.test_imgs, _ = get_test_fnames(self.args, data_path)

        fname = os.path.basename(self.test_imgs[self.test_idx]).replace("_x", "")
        np.save(os.path.join(self.save_dir, fname), prediction, allow_pickle=False)
        self.test_idx += 1
