import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data.dataset import random_split

CUDA = torch.cuda.is_available()
if CUDA:
    import cupy as cp
import numpy as np 

import itertools
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, roc_curve, precision_recall_curve, auc

import metrics as mt

class SingleTest(object):
    """Train and validate a single binary-classification model.

    Wraps a model, a loss function and an optimizer. Runs epoch loops over the
    loaders given to ``set_loaders`` and records per-epoch loss, accuracy,
    F1 score and Matthews correlation coefficient (MCC) for the training
    loader and, when one is set, the validation loader.

    Only ``nn.BCEWithLogitsLoss`` is supported: class predictions are made by
    thresholding the raw model outputs (logits) at 0. The resulting confusion
    counts are passed to the project ``metrics`` module.
    """

    def __init__(self, model, loss_fn, optimizer):
        self.CUDA = CUDA

        # Arguments
        self.model              = model.cuda() if self.CUDA else model
        self.loss_fn            = loss_fn
        self.optimizer          = optimizer
        self.device             = 'cuda' if self.CUDA else 'cpu'
        self.modelID            = None

        # Placeholders
        self.train_loader       = None
        self.val_loader         = None
        self.test_loader        = None  # kept for compatibility; unused by this class

        # Per-epoch histories (plain Python values, one entry per epoch)
        self.losses             = []
        self.val_losses         = []
        self.accuracy           = []
        self.val_accuracy       = []
        self.f1Score            = []
        self.val_f1Score        = []
        self.MCC                = []
        self.val_MCC            = []
        self.total_epochs       = 0
        self.final_loss         = None
        self.final_val_loss     = None
        self.final_accuracy     = None
        self.final_f1Score      = None
        self.final_MCC          = None
        self.final_val_accuracy = None
        self.final_val_f1Score  = None
        self.final_val_MCC      = None
        self.seed               = 0

        # Functions
        self.train_step_fn      = self._make_train_step_fn()
        self.val_step_fn        = self._make_val_step_fn()

    def to(self, device):
        """Move the model to ``device``.

        If the move fails, fall back to 'cuda' when CUDA was detected at
        import time, otherwise to 'cpu'.
        """
        try:
            self.device = device
            self.model.to(self.device)
        except (RuntimeError, AssertionError):
            # CPU-only torch builds raise AssertionError for CUDA devices;
            # invalid/unavailable devices raise RuntimeError.
            self.device = ('cuda' if self.CUDA else 'cpu')
            print(f"Couldn't send it to {device}, sending it to {self.device}")
            self.model.to(self.device)

    def set_loaders(self, train_loader, val_loader=None):
        """Register the training loader and an optional validation loader."""
        self.train_loader = train_loader
        self.val_loader = val_loader

    def set_seed(self, seed=42):
        """Seed torch, NumPy (and CuPy when CUDA is in use) for reproducibility."""
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.manual_seed(seed)
        # NumPy is used on both the CPU and CUDA paths, so always seed it.
        np.random.seed(seed)
        if self.CUDA:
            cp.random.seed(seed)
        self.seed = seed

    def _make_train_step_fn(self):
        """Build a closure doing one optimisation step; it returns the batch loss."""
        def perform_train_step_fn(x, y):
            self.model.train()
            yhat = self.model(x)
            loss = self.loss_fn(yhat, y)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            return loss.item()
        return perform_train_step_fn

    def _make_val_step_fn(self):
        """Build a closure computing the eval-mode batch loss (no update)."""
        def perform_val_step_fn(x, y):
            self.model.eval()
            yhat = self.model(x)
            loss = self.loss_fn(yhat, y)
            return loss.item()
        return perform_val_step_fn

    @staticmethod
    def _confusion_counts(logits, targets):
        """Return ``(tp, tn, fp, fn)`` for binary predictions.

        A prediction is positive when its logit is >= 0. A target counts as
        positive only when it equals 1; any other value counts as negative.
        Both inputs are flattened, so (N, 1) and (N,) shapes can be mixed.
        NOTE(review): this assumes logits and targets hold the same number
        of elements.
        """
        if torch.is_tensor(targets):
            targets = targets.detach().cpu().numpy()
        pred_pos = np.asarray(logits).reshape(-1) >= 0
        act_pos = np.asarray(targets).reshape(-1) == 1
        tp = int(np.count_nonzero(pred_pos & act_pos))
        tn = int(np.count_nonzero(~pred_pos & ~act_pos))
        fp = int(np.count_nonzero(pred_pos & ~act_pos))
        fn = int(np.count_nonzero(~pred_pos & act_pos))
        return tp, tn, fp, fn

    def _mini_batch(self, validation=False):
        """Run one pass over the training or validation loader.

        Returns ``None`` if the requested loader is not set. Otherwise it
        returns ``(mean_loss, accuracy, f1, mcc)``, where the metrics come
        from the ``metrics`` module applied to the accumulated confusion
        counts. ``mean_loss`` is NaN if the loader yields no batches.

        Training metrics come from a second, eval-mode forward pass made
        after the optimizer step, so they reflect the updated weights.

        Raises:
            NotImplementedError: if ``loss_fn`` is not ``nn.BCEWithLogitsLoss``.
        """
        if validation:
            data_loader, step_fn = self.val_loader, self.val_step_fn
        else:
            data_loader, step_fn = self.train_loader, self.train_step_fn
        if data_loader is None:
            return None
        # Check before any step runs, so an unsupported loss never updates weights.
        if not isinstance(self.loss_fn, nn.BCEWithLogitsLoss):
            raise NotImplementedError(
                f'Loss function not supported: {self.loss_fn}')

        mini_batch_losses = []
        tp = tn = fp = fn = 0
        for x_batch, y_batch in data_loader:
            x_batch = x_batch.to(self.device)
            y_batch = y_batch.to(self.device)
            mini_batch_losses.append(step_fn(x_batch, y_batch))

            logits = self.predict(x_batch)
            b_tp, b_tn, b_fp, b_fn = self._confusion_counts(logits, y_batch)
            tp += b_tp
            tn += b_tn
            fp += b_fp
            fn += b_fn

        # Store a plain float so histories stay picklable and appendable.
        loss = float(np.mean(mini_batch_losses)) if mini_batch_losses else float('nan')
        return (loss,
                mt.accuracy(tp, tn, fp, fn),
                mt.F1_Score(tp, tn, fp, fn),
                mt.MatthewsCoeff(tp, tn, fp, fn))

    @staticmethod
    def _last(values):
        """Return the last element of ``values``, or ``None`` if it is empty."""
        return values[-1] if len(values) else None

    def train(self, n_epochs, seed=49):
        """Train for ``n_epochs`` epochs, validating after each if possible.

        The per-epoch histories are appended to, and the ``final_*``
        attributes are set to the latest entries (``None`` when a history is
        empty, e.g. when no validation loader is set).

        Raises:
            ValueError: if no training loader has been set.
        """
        if self.train_loader is None:
            raise ValueError('No training loader set; call set_loaders() first.')
        self.set_seed(seed)
        for _ in range(n_epochs):
            self.total_epochs += 1
            loss, acc, f1score, mcc = self._mini_batch(validation=False)
            self.losses.append(loss)
            self.accuracy.append(acc)
            self.f1Score.append(f1score)
            self.MCC.append(mcc)

            # Validation (skipped when no validation loader is set)
            with torch.no_grad():
                val_result = self._mini_batch(validation=True)
            if val_result is not None:
                val_loss, val_acc, val_f1score, val_mcc = val_result
                self.val_losses.append(val_loss)
                self.val_accuracy.append(val_acc)
                self.val_f1Score.append(val_f1score)
                self.val_MCC.append(val_mcc)

        self.final_loss = self._last(self.losses)
        self.final_val_loss = self._last(self.val_losses)
        self.final_accuracy = self._last(self.accuracy)
        self.final_val_accuracy = self._last(self.val_accuracy)
        self.final_f1Score = self._last(self.f1Score)
        self.final_val_f1Score = self._last(self.val_f1Score)
        self.final_MCC = self._last(self.MCC)
        self.final_val_MCC = self._last(self.val_MCC)

    def save_checkpoint(self, filename):
        """Save the epoch count, model/optimizer state and loss histories."""
        checkpoint = {
                'epoch': self.total_epochs,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': self.losses,
                'val_loss': self.val_losses
        }
        torch.save(checkpoint, filename)

    def load_checkpoint(self, filename):
        """Restore state written by ``save_checkpoint`` and set train mode.

        Tensors are mapped onto ``self.device``, so GPU checkpoints can be
        loaded on CPU-only machines.
        """
        checkpoint = torch.load(filename, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.total_epochs = checkpoint['epoch']
        self.losses = checkpoint['loss']
        self.val_losses = checkpoint['val_loss']
        self.model.train()

    def export_to_onnx(self, filename, features, labels):
        """Export the model to ONNX (opset 10), with a dynamic batch axis.

        ``features`` and ``labels`` are concatenated with ``+``, so both must
        be sequences of the same type (e.g. lists). They are used as the
        input/output names, and each gets a dynamic axis 0 named 'batch_size'.
        NOTE(review): assumes ``self.model`` is indexable (e.g. nn.Sequential)
        and that its first layer exposes ``in_features`` (e.g. nn.Linear).
        """
        # Creating a dummy input on the model's current device (no implicit move).
        batch_size = 1
        num_features = self.model[0].in_features
        x = torch.randn(batch_size, num_features, requires_grad=True,
                        device=self.device)
        # Creating dynamic axes
        dynamic_axes_labels = features + labels
        dynamic_axes_names = {key: {0: 'batch_size'} for key in dynamic_axes_labels}
        self.model.eval()
        torch.onnx.export(self.model, x, filename, export_params=True,
                          opset_version=10, do_constant_folding=True,
                          input_names=features, output_names=labels,
                          dynamic_axes=dynamic_axes_names)

    def predict(self, x):
        """Return the model's raw outputs for ``x`` as a NumPy array.

        ``x`` is converted with ``torch.as_tensor(x).float()`` and moved to
        ``self.device``. The forward pass runs in eval mode without autograd,
        and the model's previous train/eval mode is restored afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                x_tensor = torch.as_tensor(x).float()
                y_hat_tensor = self.model(x_tensor.to(self.device))
        finally:
            self.model.train(was_training)
        return y_hat_tensor.detach().cpu().numpy()

    @staticmethod
    def _to_host_floats(values):
        """Convert a history to Python floats for matplotlib.

        ``float()`` also unwraps single-element NumPy/CuPy/torch values;
        ``None`` entries become NaN.
        """
        return [float('nan') if v is None else float(v) for v in values]

    def _plot_metric(self, train_values, val_values, train_label, val_label,
                     ylabel, log_scale=False):
        """Plot a training history (blue) and, if a validation loader is set,
        its validation history (red). Return the figure."""
        fig = plt.figure(figsize=(10, 4))
        plt.plot(self._to_host_floats(train_values), label=train_label, c='b')
        if self.val_loader is not None:
            plt.plot(self._to_host_floats(val_values), label=val_label, c='r')
        if log_scale:
            plt.yscale('log')
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.legend()
        plt.tight_layout()
        return fig

    def plot_losses(self):
        """Plot training/validation loss per epoch on a log scale."""
        return self._plot_metric(self.losses, self.val_losses,
                                 'Training Loss', 'Validation Loss',
                                 'Loss', log_scale=True)

    def plot_accuracy(self):
        """Plot training/validation accuracy per epoch."""
        return self._plot_metric(self.accuracy, self.val_accuracy,
                                 'Training Accuracy', 'Validation Accuracy',
                                 'Accuracy')

    def plot_f1_score(self):
        """Plot training/validation F1 score per epoch."""
        return self._plot_metric(self.f1Score, self.val_f1Score,
                                 'Training F1 Score', 'Validation F1 Score',
                                 'F1 Score')

    def plot_MCC(self):
        """Plot training/validation MCC per epoch."""
        return self._plot_metric(self.MCC, self.val_MCC,
                                 'Training MCC', 'Validation MCC',
                                 'MCC')

    def setModelID(self, ID):
        """Set an arbitrary identifier for this run."""
        self.modelID = ID

    def getModelID(self):
        """Return the identifier set by ``setModelID`` (``None`` by default)."""
        return self.modelID
