import numbers
import six

import numpy
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.collections
from matplotlib import pyplot
import seaborn as sns

from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier

from sklearn.model_selection import StratifiedKFold
from imblearn.over_sampling import SMOTE
from imblearn.over_sampling import SMOTENC
import itertools

from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score

# The colored-line plotting helpers further below (make_segments, colorline)
# are adapted from the example at:
# http://nbviewer.ipython.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb

class Classifier:
    """Wrap a scikit-learn classifier with hyper-parameter search and test-set reporting.

    Supported ``model_name`` values are 'lr' (LogisticRegression), 'svm'
    (SVC with ``probability=True``), 'rf' (RandomForestClassifier) and 'knn'
    (KNeighborsClassifier).

    Args:
        model_name: short name of the estimator to build.
        repeated: if True the grid search uses RepeatedStratifiedKFold
            (``n_repeats`` repeats, ``random_state=1``); otherwise StratifiedKFold.
        smote: if True, ``evaluate_and_test`` oversamples the training folds
            via ``build_classifier_SMOTE``.
        n_splits_cv: number of cross-validation folds.
        n_repeats: number of repeats, only used when ``repeated`` is True.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """

    def __init__(self, model_name, repeated=False, smote=False, n_splits_cv=5, n_repeats=3):
        if repeated:
            self.cv = RepeatedStratifiedKFold(n_splits=n_splits_cv, n_repeats=n_repeats, random_state=1)
        else:
            self.cv = StratifiedKFold(n_splits=n_splits_cv)

        if model_name == 'lr':
            self.clf = LogisticRegression()
        elif model_name == 'svm':
            # probability=True is needed so predict() can call predict_proba.
            self.clf = SVC(probability=True)
        elif model_name == 'rf':
            self.clf = RandomForestClassifier()
        elif model_name == 'knn':
            self.clf = KNeighborsClassifier()
        else:
            raise ValueError("Model not accepted, choose one of: lr, svm, rf, knn")

        self.model_name = model_name
        self.n_splits_cv = n_splits_cv
        self.smote = smote
        self.best = None  # fitted estimator, set by evaluate_and_test()

    def evaluate_and_test(self, X_train, y_train, X_test, y_test, params_grid, metric='f1_macro', categories=None):
        """Tune on the training set, report on the test set and store the winner in ``self.best``.

        Args:
            X_train, y_train: data used for the search and the final fit.
            X_test, y_test: held-out data used for the printed/plotted report.
            params_grid: dict mapping parameter names to lists of candidate
                values. It is not modified.
            metric: selection metric. The SMOTE path supports 'accuracy',
                'recall', 'precision' and 'f1_macro'; the other path passes it
                to GridSearchCV as ``scoring``.
            categories: indices of categorical columns (SMOTE path only); when
                given, SMOTENC is used instead of SMOTE.
        """
        if self.smote:
            # The SMOTE search iterates the folds itself and always uses a
            # plain (non-repeated) stratified K-fold.
            self.cv = StratifiedKFold(n_splits=self.n_splits_cv)
            if self.model_name == 'svm':
                # Build a new dict rather than writing into the caller's grid.
                params_grid = {**params_grid, 'probability': [True]}
            # type(self.clf) is exactly the estimator class chosen in __init__.
            self.best = build_classifier_SMOTE(X_train, y_train, X_test, y_test, type(self.clf), params_grid,
                                               cv=self.cv, metric=metric, categories=categories)
        else:
            self.best = build_classifier(X_train, y_train, X_test, y_test, self.clf, params_grid,
                                         cv=self.cv, score=metric)

    def predict(self, X_test, threshold=0.5):
        """Return ``(scores, labels)`` for ``X_test``.

        ``scores`` is column 1 of ``predict_proba`` (the positive class of a
        binary problem); ``labels`` is 1 where ``scores > threshold`` and 0
        elsewhere.

        Raises:
            RuntimeError: if no model has been fitted yet.
        """
        if self.best is None:
            raise RuntimeError("The model is not defined yet")

        scores = self.best.predict_proba(X_test)[:, 1]
        return scores, (scores > threshold).astype(int)

    def metrics_on_test(self, X_test, y_test, threshold=0.5):
        """Print test-set metrics and plot the confusion matrix at ``threshold``.

        Raises:
            RuntimeError: if no model has been fitted yet.
        """
        if self.best is None:
            raise RuntimeError("The model is not defined yet")

        # predict() returns (scores, labels); the metrics need only the labels.
        _, y_pred = self.predict(X_test, threshold)
        perf_measure(y_test, y_pred)
        plot_cm(y_test, y_pred, p=threshold)
        

# Helper function used by build_classifier_SMOTE.
def score_model(X_train, y_train, model, params, cv, metric, categories):
    """Cross-validate ``model(**params)`` with oversampling applied inside each training fold.

    The folds are created manually so that only the training part of every
    fold is oversampled; the validation part keeps its original samples.

    Args:
        X_train, y_train: training data. They are indexed positionally with the
            fold index arrays (``X_train[idx]``), so presumably NumPy arrays —
            a pandas DataFrame would need ``.iloc``.
        model: estimator class (not an instance); built as ``model(**params)``.
        params: keyword arguments for the estimator.
        cv: splitter exposing ``split(X, y)``.
        metric: one of 'accuracy', 'recall', 'precision', 'f1_macro'.
        categories: indices of the categorical columns, or None. When given,
            SMOTENC is used instead of SMOTE.

    Returns:
        The mean validation score over all folds.

    Raises:
        ValueError: if ``metric`` is not supported.
    """
    supported = ('accuracy', 'recall', 'precision', 'f1_macro')
    # Validate up front: an unknown metric used to surface as UnboundLocalError.
    if metric not in supported:
        raise ValueError(f"Unsupported metric {metric!r}; expected one of {supported}")

    if categories is None:
        smoter = SMOTE(random_state=42)
    else:
        smoter = SMOTENC(random_state=42, categorical_features=categories)

    scores = []
    for train_fold_index, val_fold_index in cv.split(X_train, y_train):
        X_train_fold, y_train_fold = X_train[train_fold_index], y_train[train_fold_index]
        X_val_fold, y_val_fold = X_train[val_fold_index], y_train[val_fold_index]

        # Upsample only the training section of the fold.
        X_train_fold_upsample, y_train_fold_upsample = smoter.fit_resample(X_train_fold, y_train_fold)
        model_obj = model(**params).fit(X_train_fold_upsample, y_train_fold_upsample)

        # Score on the (non-upsampled) validation data.
        y_val_pred = model_obj.predict(X_val_fold)
        if metric == 'accuracy':
            score = accuracy_score(y_val_fold, y_val_pred)
        elif metric == 'recall':
            score = recall_score(y_val_fold, y_val_pred)
        elif metric == 'precision':
            score = precision_score(y_val_fold, y_val_pred)
        else:  # 'f1_macro'
            score = f1_score(y_val_fold, y_val_pred, average='macro')
        scores.append(score)
    return np.array(scores).mean()

def build_classifier_SMOTE(X_train, y_train, X_test, y_test, model, params_grid, cv, metric, categories):
    """Grid-search ``model`` with oversampling-in-fold CV, refit the best one and report on the test set.

    Every combination in ``params_grid`` is scored with ``score_model``. The
    best configuration is refit on the oversampled full training set and
    evaluated on ``(X_test, y_test)`` (printed metrics, confusion matrix, ROC).

    Args:
        X_train, y_train: training data.
        X_test, y_test: held-out evaluation data.
        model: estimator class (not an instance); built as ``model(**params)``.
        params_grid: dict mapping parameter names to lists of candidate values.
        cv: splitter exposing ``split(X, y)``.
        metric: one of 'accuracy', 'recall', 'precision', 'f1_macro'.
        categories: indices of categorical columns, or None. When given,
            SMOTENC is used both during the search and for the final refit.

    Returns:
        The estimator fitted with the best parameters.

    Raises:
        ValueError: if the grid yields no combinations (an empty value list).
    """
    names = list(params_grid.keys())
    candidates = [dict(zip(names, combo)) for combo in itertools.product(*params_grid.values())]
    if not candidates:
        raise ValueError("params_grid yields no parameter combinations (is one of its value lists empty?)")

    params_score = [(score_model(X_train, y_train, model, params, cv, metric, categories), params)
                    for params in candidates]
    # max() keeps the first configuration among ties, like a stable descending sort.
    best_params = max(params_score, key=lambda item: item[0])[1]
    print(f'Best params configuration: {best_params}')
    print()

    # Oversample the full training set the same way score_model does per fold,
    # so categorical columns are not interpolated by plain SMOTE.
    if categories is None:
        smoter = SMOTE(random_state=42)
    else:
        smoter = SMOTENC(random_state=42, categorical_features=categories)
    X_train_upsample, y_train_upsample = smoter.fit_resample(X_train, y_train)
    clf = model(**best_params).fit(X_train_upsample, y_train_upsample)

    y_pred = clf.predict(X_test)
    perf_measure(y_test, y_pred)
    plot_cm(y_test, y_pred)
    plot_roc_curve(clf, X_test, y_test)
    return clf

def build_classifier(X_train, y_train, X_test, y_test, clf_to_evaluate, param_grid, cv, score, n_jobs=4):
    """Tune ``clf_to_evaluate`` with GridSearchCV, print the search results and evaluate on the test set.

    Args:
        X_train, y_train: development data used for the search.
        X_test, y_test: held-out data for the printed/plotted report.
        clf_to_evaluate: estimator instance to tune.
        param_grid: parameter grid passed to GridSearchCV.
        cv: CV splitter or fold count passed to GridSearchCV (e.g.
            StratifiedKFold, RepeatedStratifiedKFold).
        score: scoring passed to GridSearchCV (e.g. 'f1_macro').
        n_jobs: number of parallel jobs for the search. Workers may each hold
            a copy of the training data, so raise this with care.

    Returns:
        ``best_estimator_`` of the fitted search.
    """
    print(f"# Tuning hyper-parameters for {score}")
    print()

    clf = GridSearchCV(clf_to_evaluate, param_grid, cv=cv,
                       scoring=score, verbose=True, n_jobs=n_jobs)
    clf.fit(X_train, y_train)

    print("Best parameters set found on development set:")
    print()
    print(clf.best_params_)
    print()
    print("Grid scores on development set:")
    print()
    results = clf.cv_results_
    for mean, std, params in zip(results['mean_test_score'], results['std_test_score'], results['params']):
        # Mean CV score +/- two standard deviations across folds.
        print(f"{mean:.3f} (+/-{std * 2:.3f}) for {params!r}")
    print()

    print("Detailed classification report:")
    print()
    print("The model is trained on the full development set.")
    print("The scores are computed on the full test set.")
    print()
    y_pred = clf.predict(X_test)
    perf_measure(y_test, y_pred)
    plot_cm(y_test, y_pred)
    plot_roc_curve(clf, X_test, y_test)
    return clf.best_estimator_


def perf_measure(y_true, y_pred, scores=None):
    """Print a classification report comparing ``y_pred`` against ``y_true``.

    Only entries 0 and 1 of sklearn's per-class arrays are reported, so a
    binary problem is assumed.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        scores: optional positive-class scores; when supplied, ROC AUC is
            printed as well.

    Returns:
        ``(f1_macro, AUC)`` when ``scores`` is given, otherwise ``None``.
    """
    acc = accuracy_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred, average=None)      # one value per class
    prec = precision_score(y_true, y_pred, average=None)  # one value per class
    f1_per_class = f1_score(y_true, y_pred, average=None)
    f1_avg = f1_score(y_true, y_pred, average='macro')

    print(f'Accuracy: {acc}')
    print(f'Recall class 0 (TNR): {rec[0]}')
    print(f'Recall class 1 (TPR): {rec[1]}')
    print(f'Precision class 0: {prec[0]}')
    print(f'Precision class 1: {prec[1]}')
    print(f'F1 class 0: {f1_per_class[0]}')
    print(f'F1 class 1: {f1_per_class[1]}')
    print(f'F1_macro: {f1_avg}')
    print()
    # The error rates are the complements of the per-class recalls.
    print(f'False negative rate: {1 - rec[1]}')
    print(f'False positive rate: {1 - rec[0]}')

    if scores is not None:
        auc_value = roc_auc_score(y_true, scores)
        print(f'AUC score: {auc_value}')
        return f1_avg, auc_value
    return None

def plot_cm(labels, predictions, p=0.5):
    """Draw an annotated confusion-matrix heatmap in a new 5x5 figure.

    ``predictions`` are used as given (already-thresholded labels); ``p`` is
    only shown in the title.
    """
    matrix = confusion_matrix(labels, predictions)
    plt.figure(figsize=(5, 5))
    sns.heatmap(matrix, annot=True, fmt="d")
    plt.title('Confusion matrix @{:.2f}'.format(p))
    plt.ylabel('Actual label')
    plt.xlabel('Predicted label')


def plot_roc_curve(clf, X, y_true):
    """Plot the ROC curve of ``clf`` on ``(X, y_true)`` and print its AUC.

    Column 1 of ``clf.predict_proba(X)`` is used as the positive-class score;
    every 250th point of the curve is annotated with its threshold.
    """
    positive_score = clf.predict_proba(X)[:, 1]
    fpr, tpr, thresholds = roc_curve(y_true, positive_score)
    plot_roc(tpr, fpr, thresholds, label_every=250)
    auc_value = roc_auc_score(y_true, positive_score)
    print(f'AUC score: {auc_value}')
 
  
def make_segments(x, y):
    """Turn point coordinates into consecutive line segments for a LineCollection.

    Returns an array of shape ``(n_points - 1, 2, 2)``: one segment per
    consecutive pair of points, each holding its start and end ``(x, y)``.
    """
    pts = numpy.array([x, y]).T.reshape(-1, 1, 2)
    # Pair each point with its successor along axis 1.
    return numpy.hstack((pts[:-1], pts[1:]))


def colorline(x, y, z=None, axes=None,
              cmap=None,
              norm=None, linewidth=3, alpha=1.0,
              **kwargs):
    """Plot a line through ``(x, y)`` whose segments are colored by ``z``.

    Args:
        x, y: point coordinates.
        z: color values; defaults to ``len(x)`` values evenly spaced on
            [0, 1]. A single real number is wrapped in a one-element array.
        axes: target axes; defaults to the current axes (``pyplot.gca()``).
        cmap: colormap; None (the default) means the 'coolwarm' colormap.
        norm: normalization for ``z``; None (the default) means a fresh
            ``Normalize(0.0, 1.0)`` created for this call.
        linewidth, alpha: line width and opacity.
        **kwargs: forwarded to ``LineCollection``.

    Returns:
        The LineCollection added to ``axes``.
    """
    # Build defaults per call: a Normalize evaluated once in the signature
    # would be a single mutable object shared by every line ever drawn.
    if cmap is None:
        cmap = pyplot.get_cmap('coolwarm')
    if norm is None:
        norm = pyplot.Normalize(0.0, 1.0)

    # Default colors equally spaced on [0, 1].
    if z is None:
        z = numpy.linspace(0.0, 1.0, len(x))

    # Special case if a single number.
    if isinstance(z, numbers.Real):
        z = numpy.array([z])

    z = numpy.asarray(z)

    segments = make_segments(x, y)
    lc = matplotlib.collections.LineCollection(
        segments, array=z, cmap=cmap, norm=norm,
        linewidth=linewidth, alpha=alpha, **kwargs
    )

    if axes is None:
        axes = pyplot.gca()

    axes.add_collection(lc)
    axes.autoscale()

    return lc


def plot_roc(tpr, fpr, thresholds, subplots_kwargs=None,
             label_every=None, label_kwargs=None,
             fpr_label='False Positive Rate',
             tpr_label='True Positive Rate',
             luck_label='Luck',
             title='Receiver operating characteristic',
             **kwargs):
    """Plot an ROC curve, colored by threshold, in a new figure.

    Args:
        tpr, fpr, thresholds: parallel sequences (e.g. from
            ``sklearn.metrics.roc_curve``).
        subplots_kwargs: extra keyword arguments for ``pyplot.subplots``.
        label_every: if given, annotate every ``label_every``-th point with its
            threshold rounded to 2 decimals. Must be a positive integer.
        label_kwargs: extra keyword arguments for ``Axes.annotate``; a yellow
            rounded bbox is used unless 'bbox' is supplied. Not modified.
        fpr_label, tpr_label, title: axis labels and figure title.
        luck_label: legend label of the dashed diagonal; None omits the line.
        **kwargs: forwarded to ``Axes.plot`` for the ROC line; the line width
            defaults to 1 unless 'lw' or 'linewidth' is given.

    Returns:
        ``(figure, axes)``.

    Raises:
        ValueError: if ``label_every`` is not positive.
    """
    # Validate before creating a figure so an error leaves no empty plot behind.
    if label_every is not None and label_every <= 0:
        raise ValueError(f"label_every must be a positive integer, got {label_every!r}")

    if subplots_kwargs is None:
        subplots_kwargs = {}

    figure, axes = pyplot.subplots(1, 1, **subplots_kwargs)

    # matplotlib rejects receiving both 'lw' and its alias 'linewidth', so only
    # default the width when the caller set neither.
    if 'lw' not in kwargs and 'linewidth' not in kwargs:
        kwargs['lw'] = 1

    axes.plot(fpr, tpr, **kwargs)

    if label_every is not None:
        # Copy so the default bbox is not written into the caller's dict.
        label_kwargs = dict(label_kwargs or {})
        label_kwargs.setdefault('bbox', dict(
            boxstyle='round,pad=0.5', fc='yellow', alpha=0.5,
        ))

        for k in range(0, len(tpr), label_every):
            threshold = str(numpy.round(thresholds[k], 2))
            axes.annotate(threshold, (fpr[k], tpr[k]), **label_kwargs)

    if luck_label is not None:
        axes.plot((0, 1), (0, 1), '--', color='Gray', label=luck_label)

    lc = colorline(fpr, tpr, thresholds, axes=axes)
    figure.colorbar(lc)

    axes.set_xlim([-0.05, 1.05])
    axes.set_ylim([-0.05, 1.05])

    axes.set_xlabel(fpr_label)
    axes.set_ylabel(tpr_label)

    axes.set_title(title)

    axes.legend(loc="lower right")

    return figure, axes