import torch
import pandas as pd
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler,StandardScaler,LabelEncoder
import math
import numpy as np
from plot_module import plot_matrix_grid
from sdv.metadata import Metadata
from sdv.sampling import Condition
from sdv.single_table import TVAESynthesizer
from scipy.spatial.distance import jensenshannon
from concurrent.futures import ProcessPoolExecutor
import time
from tqdm import tqdm
from utils import reverse_get_dummies
from plot_module import plot_matrix_grid

# Names of the continuous process-parameter columns.
all_continous_features = [
    "THICKNESS_TULUS [mm]",
    "CONTOUR_SPEED [mm/min]",
    "LASER_POWER [W]",
    "CONTOUR_GAS_PRESSURE [bar]",
    "CONTOUR_NOZZLE_DISTANCE [mm]",
    "CONTOUR_FOCAL [mm]",
]

# Categorical columns passed to the one-hot encoding step.
categorical_features = ["MATERIAL_NAME_TULUS", "NOZZLE_TYPE"]

# Column name -> human-readable axis label (unit in parentheses) for plots.
# NOTE(review): 'CONTOUR_NOZZLE_DISTANCE [mm]' has no entry here; confirm that
# it is intentionally never used as a plotted axis.
label_map = {
    'CONTOUR_SPEED [mm/min]': 'Contour Speed (mm/min)',
    'LASER_POWER [W]': 'Laser Power (W)',
    'CONTOUR_GAS_PRESSURE [bar]': 'Contour Gas Pressure (bar)',
    'CONTOUR_FOCAL [mm]': 'Contour Focal (mm)',
    'THICKNESS_TULUS [mm]': 'Thickness Tulus (mm)',
}



def flatten_categorical_columns(df,columns,only_quality_cut=False):
    """One-hot encode ``columns`` and keep exactly one target column.

    Args:
        df: Input frame. It must contain ``QUALITY_CUT``, ``DEFECT_TYPE``,
            ``TECHNOLOGY_GAS``, ``CONTOUR_LASER_MODE`` and ``LASER_TYPE`` in
            addition to the columns listed in ``columns``. When the nozzle
            split below is taken, the new columns are assigned on ``df``
            itself, so pass a copy if the caller's frame must stay untouched.
        columns: Names of the categorical columns handed to ``pd.get_dummies``.
        only_quality_cut: If true, keep ``QUALITY_CUT`` (encoded as
            ``Good`` -> 1, ``Bad`` -> 0) and drop ``DEFECT_TYPE``; otherwise
            keep ``DEFECT_TYPE`` and drop ``QUALITY_CUT``.

    Returns:
        The encoded frame with rows de-duplicated on every non-target column
        (first occurrence kept) and a fresh ``RangeIndex``.

    Raises:
        KeyError: If one of the required columns is missing.
    """
    # NOZZLE_TYPE is split into (NOZZLE_BASE_TYPE, NOZZLE_SIZE) so that, within
    # each base type, the sizes can be treated as internally ordered.
    # NOTE(review): the guard tests for NOZZLE_BASE_TYPE in ``df.columns`` while
    # the body creates that column from NOZZLE_TYPE; confirm whether the check
    # was meant to look at ``columns`` instead. Left unchanged on purpose.
    if 'NOZZLE_BASE_TYPE' in df.columns:
        df["NOZZLE_SIZE"] = df["NOZZLE_TYPE"].str.extract(r'([\d\.]+)').astype(float)
        df["NOZZLE_BASE_TYPE"] = df["NOZZLE_TYPE"].str.extract(r'([A-Z]+)')
        df = df.drop("NOZZLE_TYPE", axis=1)

    df_encoded = pd.get_dummies(df, columns=columns)

    # Newer pandas emits boolean dummies; store them as 0/1 integers.
    bool_cols = df_encoded.select_dtypes(bool).columns
    if len(bool_cols):
        df_encoded[bool_cols] = df_encoded[bool_cols].astype(int)

    if only_quality_cut:
        # ``map`` with a pass-through fallback leaves unknown labels untouched
        # (as ``replace`` did) without relying on replace's deprecated silent
        # downcasting to get a numeric column.
        quality_codes = {"Good": 1, "Bad": 0}
        df_encoded["QUALITY_CUT"] = df_encoded["QUALITY_CUT"].map(
            lambda value: quality_codes.get(value, value)
        )
        df_encoded = df_encoded.drop("DEFECT_TYPE", axis=1)
    else:
        df_encoded = df_encoded.drop("QUALITY_CUT", axis=1)

    # Rows that agree on every feature are duplicates regardless of the target.
    feature_cols = [col for col in df_encoded.columns if col not in ('DEFECT_TYPE', 'QUALITY_CUT')]
    df_encoded = df_encoded.drop_duplicates(subset=feature_cols, keep='first').reset_index(drop=True)

    # Only one of the two target columns survives, so derive the good/bad split
    # from whichever is present (indexing DEFECT_TYPE unconditionally raised
    # KeyError when only_quality_cut=True).
    if "DEFECT_TYPE" in df_encoded.columns:
        is_good = df_encoded["DEFECT_TYPE"] == "No Defects"
    else:
        is_good = df_encoded["QUALITY_CUT"] == 1
    print("example: \n")
    print(f"total samples in df_encoded: {len(df_encoded)}")
    print(f'total samples good in df_encoded: {int(is_good.sum())}')
    print(f'total samples bad in df_encoded: {int((~is_good).sum())}')

    # These columns are dropped unless the caller asked for them to be encoded.
    for passthrough in ("TECHNOLOGY_GAS", "CONTOUR_LASER_MODE", "LASER_TYPE"):
        if passthrough not in columns:
            df_encoded = df_encoded.drop([passthrough], axis=1)
    return df_encoded



def frobenius_norm(A):
    """Return the Frobenius norm of the matrix ``A`` via ``numpy.linalg.norm``."""
    norm_value = np.linalg.norm(A, ord='fro')
    return norm_value



def inverse_scaling(gen_scaled,scaler):
    """Undo ``scaler``'s transformation on ``gen_scaled``.

    The values returned by ``scaler.inverse_transform`` are wrapped in a new
    DataFrame that reuses the column labels and index of ``gen_scaled``.
    """
    restored_values = scaler.inverse_transform(gen_scaled)
    return pd.DataFrame(
        restored_values,
        columns=gen_scaled.columns,
        index=gen_scaled.index,
    )

def fitness(tournament_candidates, gen_samples, real_samples,target_size,continous_features,defect):
    """Select generated rows whose correlation matrix best matches the real one.

    The score of a set of rows is the Frobenius norm of the difference between
    the real correlation matrix and the set's correlation matrix (NaN
    correlations, e.g. from constant columns, count as 0).

    1. Tournament: every candidate batch is scored and the best one becomes
       the initial subset; its rows are removed from the pool.
    2. Greedy growth: until the subset holds ``target_size`` rows (or the pool
       is exhausted), the single remaining row whose addition yields the
       lowest score is appended and removed from the pool.

    A heatmap of the initial subset's correlations is written to
    ``greedy_results/greedy_initial_subset_corr_{defect}.png``.

    Args:
        tournament_candidates: Non-empty sequence of lists of positional row
            indices into ``gen_samples``.
        gen_samples: Pool of generated rows (not modified in place).
        real_samples: Real rows defining the reference correlation matrix.
        target_size: Desired number of rows in the returned subset. If the
            initial batch is already at least this large, no rows are added.
        continous_features: Unused; kept for interface compatibility.
        defect: Label used only in the output file name.

    Returns:
        DataFrame with the selected rows and a fresh ``RangeIndex``.

    Raises:
        ValueError: If ``tournament_candidates`` is empty.
    """
    print("-----------------------------------fitness--------------------------------------------")
    if len(tournament_candidates) == 0:
        raise ValueError("tournament_candidates must not be empty")

    real_corr = real_samples.corr().fillna(0)

    def corr_distance(frame):
        # Align on labels before subtracting: a positional ``.values``
        # difference silently pairs the wrong features when column order
        # differs and fails outright when the generated frame has extra columns.
        frame_corr = (
            frame.corr()
            .reindex(index=real_corr.index, columns=real_corr.columns)
            .fillna(0)
        )
        return frobenius_norm(real_corr.values - frame_corr.values)

    # --- Tournament: keep the batch with the smallest Frobenius distance ---
    fitness_scores = [
        [corr_distance(gen_samples.iloc[indices]), indices]
        for indices in tournament_candidates
    ]
    min_div, min_frob_indices = min(fitness_scores, key=lambda x: x[0])
    greedyInitialSubset = gen_samples.iloc[min_frob_indices].reset_index(drop=True)
    # The batch holds positions, ``drop`` expects labels: translate first so the
    # right rows are removed even when the index is not a default RangeIndex.
    gen_samples = gen_samples.drop(gen_samples.index[min_frob_indices])

    print("min frobenius norm:")
    print(min_div)
    print(f"Greedy Initial Subset initial len:{len(greedyInitialSubset)}")
    init_len = len(greedyInitialSubset)

    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(greedyInitialSubset.corr(), annot=True, cmap='coolwarm', vmin=-1, vmax=1, center=0)
    plt.title('Matrice di Correlazione dati sintetici')
    plt.savefig(f"greedy_results/greedy_initial_subset_corr_{defect}.png")
    # Release the figure; otherwise one figure stays alive per call.
    plt.close(fig)

    # --- Greedy growth: add the single best remaining row each round ---
    for _ in tqdm(range(target_size - init_len)):
        if len(gen_samples) == 0:
            # Pool exhausted before reaching target_size: return what we have.
            break
        best_distance = None
        best_candidate = None
        best_label = None
        for i in range(len(gen_samples)):
            row = gen_samples.iloc[i]
            candidate = pd.DataFrame([row.values], columns=gen_samples.columns)
            trial = pd.concat([greedyInitialSubset, candidate], axis=0, ignore_index=True)
            distance = corr_distance(trial)
            # Strict '<' keeps the first minimum, matching min() semantics.
            if best_distance is None or distance < best_distance:
                best_distance = distance
                best_candidate = candidate
                best_label = row.name

        greedyInitialSubset = pd.concat([greedyInitialSubset, best_candidate], ignore_index=True)
        gen_samples = gen_samples.drop(best_label)

    final_frob_distance = corr_distance(greedyInitialSubset)
    print(f"final frobenius distance: {final_frob_distance}")

    return greedyInitialSubset



def tournament_selection(gen_samples,real_samples,B,target_size,continous_features,defect):
    """Draw ``B`` random candidate batches and return the subset chosen by ``fitness``.

    Each batch is a list of distinct positional row indices into
    ``gen_samples``, sized like ``real_samples``.

    Args:
        gen_samples: Pool of generated rows.
        real_samples: Real rows used as the reference by ``fitness``.
        B: Number of random batches to draw.
        target_size: Forwarded to ``fitness``.
        continous_features: Forwarded to ``fitness``.
        defect: Forwarded to ``fitness``.

    Returns:
        Whatever ``fitness`` returns for the drawn batches (the selected
        subset of generated rows).
    """
    gen_samples_size = len(gen_samples)
    # Sampling without replacement cannot draw more rows than the pool holds;
    # clamp so a small generated pool does not make np.random.choice raise.
    draw_size = min(len(real_samples), gen_samples_size)
    tournament_candidates = [
        np.random.choice(gen_samples_size, draw_size, replace=False).tolist()
        for _ in range(B)
    ]
    return fitness(tournament_candidates, gen_samples, real_samples, target_size, continous_features, defect)


# Continuous columns used for the correlation comparison in the loop below.
all_cont_features = [
    "THICKNESS_TULUS [mm]",
    "CONTOUR_SPEED [mm/min]",
    "LASER_POWER [W]",
    "CONTOUR_GAS_PRESSURE [bar]",
    'CONTOUR_NOZZLE_DISTANCE [mm]',
    "CONTOUR_FOCAL [mm]",
]

full_real_data = pd.read_excel('./merged_files.xlsx')
full_gen_data = pd.read_excel('./synthetic_data_tvae(defects_2000_epochs).xlsx')

# Generated rows that agree on every non-target column are duplicates.
_gen_feature_columns = [
    col for col in full_gen_data.columns if col not in ['DEFECT_TYPE', 'QUALITY_CUT']
]
full_gen_data = full_gen_data.drop_duplicates(
    subset=_gen_feature_columns, keep='first', inplace=False
).reset_index(drop=True)

print(full_real_data.columns)

# Per defect class: the target size is the number of real rows of that class
# times this multiplier.
_target_multipliers = {
    'No Defects': 3,
    'Burr': 3,
    'Cutting loss': 6,
    'Plasma': 4,
    'Cutting torn': 4,
}
defects_params = {
    defect_name: {
        'target_size': len(full_real_data[full_real_data["DEFECT_TYPE"] == defect_name]) * multiplier,
        'batch_size': 1000,
    }
    for defect_name, multiplier in _target_multipliers.items()
}


def _draw_corr_heatmap(corr, title, out_path=None):
    """Draw an annotated correlation heatmap, optionally save it, then close it."""
    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1, center=0)
    plt.title(title)
    if out_path is not None:
        plt.savefig(out_path)
    # Figures are never shown interactively here; closing them keeps them from
    # piling up across defect classes.
    plt.close(fig)


# For every defect class: clean the generated rows, pick the subset whose
# correlations best match the real data, and plot real vs. synthetic.
for defect, params in defects_params.items():
    print(f"defect: {defect}")
    continous_features = all_cont_features.copy()
    real_data=full_real_data[full_real_data["DEFECT_TYPE"]==defect].copy().reset_index(drop=True)
    gen_data=full_gen_data[full_gen_data["DEFECT_TYPE"]==defect].copy().reset_index(drop=True)
    gen_data['CONTOUR_LASER_MODE'] = 'GP'
    gen_data['TECHNOLOGY_GAS'] = 'N2'
    # QUALITY_CUT is one of the join keys of the merge below, so it has to agree
    # with the real rows of this class; labelling 'No Defects' rows as 'Bad'
    # meant real/synthetic duplicates of that class were never removed.
    # NOTE(review): assumes real 'No Defects' rows carry QUALITY_CUT == 'Good'
    # and every other class 'Bad' -- confirm against merged_files.xlsx.
    gen_data['QUALITY_CUT'] = 'Good' if defect == 'No Defects' else 'Bad'
    gen_data['LASER_TYPE']='YLS'
    print(f"real_data columns: {real_data.columns}")
    print(f"gen_data columns: {gen_data.columns}")
    print(f"real_data shape: {real_data.shape}")
    print(f"gen_data shape: {gen_data.shape}")
    # Remove generated rows that also appear among the real rows.
    df_diff = pd.merge(gen_data, real_data, on=list(gen_data.columns), how='outer', indicator=True)
    gen_data = df_diff[df_diff['_merge'] == 'left_only'].drop(columns=['_merge']).reset_index(drop=True)

    # De-duplicate both sets, ignoring the targets and the nozzle distance.
    dedup_ignored = ['DEFECT_TYPE', 'QUALITY_CUT','CONTOUR_NOZZLE_DISTANCE [mm]']
    gen_data=gen_data.drop_duplicates(subset=[col for col in gen_data.columns if col not in dedup_ignored], keep='first', inplace=False).reset_index(drop=True)
    real_data=real_data.drop_duplicates(subset=[col for col in real_data.columns if col not in dedup_ignored], keep='first', inplace=False).reset_index(drop=True)
    print(f"real_data shape: {real_data.shape}")
    print(f"gen_data shape: {gen_data.shape}")
    real_data=flatten_categorical_columns(real_data.copy(),categorical_features,only_quality_cut=False)
    gen_data=flatten_categorical_columns(gen_data.copy(),categorical_features,only_quality_cut=False)
    gen_data=gen_data.drop(['DEFECT_TYPE'],axis=1)
    real_data=real_data.drop(["DEFECT_TYPE"],axis=1)
    print(f"real_data columns: {real_data.columns}")
    print(f"gen_data columns: {gen_data.columns}")
    # Columns that exist only in the real data are excluded from the reference
    # handed to the selection step.
    absent_columns=[col for col in real_data.columns if col not in gen_data.columns]
    print(f"continous features: {continous_features}")
    print(f"real_data columns: {real_data.columns}")
    print(f"gen_data columns: {gen_data.columns}")
    print(f"real_data shape: {real_data.shape}")
    print(f"gen_data shape: {gen_data.shape}")
    greedy_subset=tournament_selection(gen_data,real_data.copy().drop(absent_columns,axis=1),params['batch_size'],params['target_size'],continous_features,defect)
    print(f"greedy_subset shape: {greedy_subset.shape}")
    print(f"greedy_subset columns: {greedy_subset.columns}")
    #plot_matrix_grid(greedy_subset,speed_is_dense=True,pressure_is_dense=True,title='greedy_results/greedy_subset')
    #plot_matrix_grid(real_data,speed_is_dense=True,pressure_is_dense=True,title='greedy_results/real_data')

    # Correlation heatmaps (all columns, then continuous columns only). Saving
    # is currently disabled; pass the out_path noted next to each call to
    # write the image.
    real_corr=real_data.corr()
    _draw_corr_heatmap(real_corr, 'Matrice di Correlazione dati reali')  # f"greedy_results/real_full_corr_{defect}.png"

    real_corr=real_data[continous_features].corr()
    _draw_corr_heatmap(real_corr, 'Matrice di Correlazione dati reali')  # f"greedy_results/real_corr_{defect}.png"

    greedy_corr=greedy_subset.corr()
    _draw_corr_heatmap(greedy_corr, 'Matrice di Correlazione dati sintetici')  # f"greedy_results/greedy_subset_full_corr_{defect}.png"

    greedy_corr=greedy_subset[continous_features].corr()
    _draw_corr_heatmap(greedy_corr, 'Matrice di Correlazione dati sintetici')  # f"greedy_results/greedy_subset_corr_{defect}.png"

    # plotting bivariate/multivariate correlations

    features_to_plot = [
        "THICKNESS_TULUS [mm]",
        'CONTOUR_SPEED [mm/min]',
        "LASER_POWER [W]",
        'CONTOUR_GAS_PRESSURE [bar]',
        'CONTOUR_FOCAL [mm]',

    ]

    greedy_subset=reverse_get_dummies(greedy_subset.copy(),categorical_features)
    print(greedy_subset.columns)
    print(real_data.columns)
    # === Scatter matrix with axis labels that include the units ===
    for name, data in [('Synthetic', greedy_subset), ('Real', real_data)]:
        df_scatter = data[features_to_plot].copy()
        g = sns.pairplot(df_scatter)
        # Label every axis of row i / column i with the unit-bearing name.
        for i, var in enumerate(features_to_plot):
            for ax in g.axes[i, :]:
                ax.set_ylabel(label_map[var])
            for ax in g.axes[:, i]:
                ax.set_xlabel(label_map[var])
        plt.suptitle(f'Scatter Matrix - {name} Data', y=1.02)
        plt.tight_layout()
        plt.savefig(f'greedy_results/scatter_{name.lower()}_{defect}.png')
        # Close the figure that was just saved (the current pyplot figure).
        plt.close()


    plot_matrix_grid(greedy_subset,speed_is_dense=True,pressure_is_dense=True,title=f'greedy_results/greedy_subset_{defect}')
    #greedy_subset.to_excel(f"greedy_subset_{defect}.xlsx",index=False)