import torch
import pandas as pd
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler,StandardScaler,LabelEncoder
import math
import numpy as np
from plot_module import plot_matrix_grid
from sdv.metadata import Metadata
from sdv.sampling import Condition
from sdv.single_table import TVAESynthesizer
from scipy.spatial.distance import jensenshannon
from concurrent.futures import ProcessPoolExecutor
import time

# Continuous (numeric) feature columns, one entry per column name.
continous_features = [
    "THICKNESS_TULUS [mm]",
    "CONTOUR_SPEED [mm/min]",
    "LASER_POWER [W]",
    "CONTOUR_GAS_PRESSURE [bar]",
    "CONTOUR_NOZZLE_DISTANCE [mm]",
    "CONTOUR_FOCAL [mm]",
]
# Categorical columns handed to flatten_categorical_columns for one-hot encoding.
categorical_features = [
    "MATERIAL_NAME_TULUS",
    "NOZZLE_TYPE",
]



def flatten_categorical_columns(df,columns,only_quality_cut=False):
    """One-hot encode ``columns`` of ``df`` and tidy the result.

    Steps:
      * one-hot encode ``columns`` with ``pd.get_dummies`` (boolean dummies
        are converted to 0/1 ints);
      * if ``only_quality_cut``: map QUALITY_CUT "Good"/"Bad" to 1/0 and drop
        DEFECT_TYPE; otherwise drop QUALITY_CUT;
      * drop rows that are duplicates on every column except DEFECT_TYPE and
        QUALITY_CUT (first occurrence kept) and reset the index;
      * print the number of samples and the good/bad split;
      * drop TECHNOLOGY_GAS, CONTOUR_LASER_MODE and LASER_TYPE unless they
        were among ``columns``.

    Args:
        df: input frame. It is modified in place only when it already has a
            NOZZLE_BASE_TYPE column (see the note below).
        columns: column names to one-hot encode.
        only_quality_cut: keep QUALITY_CUT (as 1/0) instead of DEFECT_TYPE.

    Returns:
        The encoded DataFrame with a fresh RangeIndex.

    Raises:
        KeyError: if a column the function reads or drops is missing
            (QUALITY_CUT, DEFECT_TYPE when ``only_quality_cut``, or one of
            TECHNOLOGY_GAS / CONTOUR_LASER_MODE / LASER_TYPE not in ``columns``).
    """
    # Intent: split NOZZLE_TYPE into (NOZZLE_BASE_TYPE, NOZZLE_SIZE) so that the
    # sizes of each base type get an internal ordering.
    # NOTE(review): the guard only fires when NOZZLE_BASE_TYPE already exists;
    # 'NOZZLE_TYPE' in df.columns was probably intended. Left unchanged because
    # enabling the split would add a string column (NOZZLE_BASE_TYPE) that is
    # not one-hot encoded unless it is also listed in ``columns``.
    # NOZZLE_TYPE itself is kept so that get_dummies can still encode it.
    if 'NOZZLE_BASE_TYPE' in df.columns:
        df["NOZZLE_SIZE"] = df["NOZZLE_TYPE"].str.extract(r'([\d\.]+)').astype(float)
        df["NOZZLE_BASE_TYPE"] = df["NOZZLE_TYPE"].str.extract(r'([A-Z]+)')

    df_encoded = pd.get_dummies(df, columns=columns)
    bool_cols = df_encoded.select_dtypes(bool).columns
    df_encoded[bool_cols] = df_encoded[bool_cols].astype(int)

    if only_quality_cut:
        df_encoded["QUALITY_CUT"] = df_encoded["QUALITY_CUT"].replace({"Good": 1, "Bad": 0})
        df_encoded = df_encoded.drop("DEFECT_TYPE", axis=1)
    else:
        df_encoded = df_encoded.drop("QUALITY_CUT", axis=1)

    # Deduplicate on the features only: rows differing just in their label
    # columns are treated as duplicates.
    label_cols = ['DEFECT_TYPE', 'QUALITY_CUT']
    feature_cols = [col for col in df_encoded.columns if col not in label_cols]
    df_encoded = df_encoded.drop_duplicates(subset=feature_cols, keep='first').reset_index(drop=True)

    print("example: \n")
    print(f"total samples in df_encoded: {len(df_encoded)}")
    # DEFECT_TYPE was dropped above when only_quality_cut=True; fall back to
    # the 1/0 QUALITY_CUT column instead of raising KeyError.
    if "DEFECT_TYPE" in df_encoded.columns:
        is_good = df_encoded["DEFECT_TYPE"] == "No Defects"
    else:
        is_good = df_encoded["QUALITY_CUT"] == 1
    print(f'total samples good in df_encoded: {int(is_good.sum())}')
    print(f'total samples bad in df_encoded: {int((~is_good).sum())}')

    # These columns are removed unless the caller asked to one-hot encode them.
    for col in ("TECHNOLOGY_GAS", "CONTOUR_LASER_MODE", "LASER_TYPE"):
        if col not in columns:
            df_encoded = df_encoded.drop([col], axis=1)
    return df_encoded


def compute_jsd_matrixwise(real_matrix, gen_matrix):
    """Return the average Jensen-Shannon distance between two sets of rows.

    For every row ``r`` of ``real_matrix`` the mean JS distance from ``r`` to
    all rows of ``gen_matrix`` is computed. The result is the mean of those
    per-row means. Every real row is compared against the same number of
    generated rows, so this equals the mean over all (real, generated) pairs.
    The score measures the divergence between the real distribution and the
    current generated batch (lower is closer).

    Args:
        real_matrix: 2-D array-like, one (non-negative) vector per row.
        gen_matrix: 2-D array-like with the same number of columns.

    Returns:
        The average distance as a float; nan if either matrix has no rows.
    """
    real_matrix = np.asarray(real_matrix, dtype=float)
    gen_matrix = np.asarray(gen_matrix, dtype=float)
    if real_matrix.shape[0] == 0 or gen_matrix.shape[0] == 0:
        return float("nan")

    # One vectorised call per real row: the row is broadcast against every
    # generated row and distances are taken along the feature axis.
    per_real_row = [
        np.mean(jensenshannon(np.broadcast_to(row, gen_matrix.shape), gen_matrix, axis=1))
        for row in real_matrix
    ]
    return np.mean(per_real_row)



def normalize_rows(real_scaled,gen_scaled):
    """Rescale every row of both frames so that it sums to (about) 1.

    Each row is divided by its own total plus 1e-8. The epsilon keeps
    all-zero rows from causing a division by zero, so such rows stay all-zero.
    The results can be treated as probability-like vectors.

    Returns:
        ``(real_normalized, gen_normalized)`` with the same index and columns
        as the inputs.
    """
    eps = 1e-8

    def _to_unit_rows(frame):
        totals = frame.sum(axis=1) + eps
        return frame.div(totals, axis=0)

    return _to_unit_rows(real_scaled), _to_unit_rows(gen_scaled)


def denormalize_rows(gen_normalized, row_sum=None):
    """Undo the row normalisation by multiplying each row by a row total.

    Dividing a row by its sum discards that sum, so it cannot be recovered
    from ``gen_normalized`` alone. To invert a normalisation exactly, pass
    the divisors that were used (for ``normalize_rows`` that is
    ``scaled.sum(axis=1) + 1e-8``) as ``row_sum``.

    Args:
        gen_normalized: DataFrame of row-normalised values.
        row_sum: optional Series aligned on the row index, or array of length
            ``len(gen_normalized)``, holding the per-row multipliers.
            When omitted, each row is multiplied by its own sum plus 1e-8.
            This is backward-compatible, but for rows that already sum to ~1
            it is effectively a no-op.

    Returns:
        DataFrame with the same index and columns, rows rescaled.
    """
    if row_sum is None:
        row_sum = gen_normalized.sum(axis=1) + 1e-8
    return gen_normalized.mul(row_sum, axis=0)

def inverse_scaling(gen_scaled,scaler):
    """Map scaled samples back to their original units.

    Args:
        gen_scaled: DataFrame of scaled values.
        scaler: a fitted object exposing ``inverse_transform``.

    Returns:
        DataFrame of the inverse-transformed values, keeping the index and
        column labels of ``gen_scaled``.
    """
    restored_values = scaler.inverse_transform(gen_scaled)
    return pd.DataFrame(restored_values, index=gen_scaled.index, columns=gen_scaled.columns)

def fitness(tournament_candidates, gen_samples, real_samples,target_size):
    """Select rows of ``gen_samples`` whose distribution is close to ``real_samples``.

    Both frames are MinMax-scaled with a scaler fitted on ``real_samples``, then
    row-normalised with ``normalize_rows``. All distances are computed on those
    normalised rows with ``compute_jsd_matrixwise``.

    The selection has two stages:
      1. Tournament: every batch in ``tournament_candidates`` is scored against
         the real rows; the lowest-scoring batch becomes the initial subset.
      2. Greedy extension: every other generated row is scored by the JSD the
         initial subset would have with that single row appended. The best
         ``target_size - len(initial subset)`` rows are added. Each candidate
         is scored independently against the initial subset, not re-scored
         after other candidates are added.

    Side effects: prints progress and saves a correlation heatmap of the
    initial subset (normalised values) to ``greedy_initial_subset_corr.png``.

    Args:
        tournament_candidates: iterable of lists of *positional* row indices
            into ``gen_samples``.
        gen_samples: generated samples with the same columns as
            ``real_samples``. Its index labels are assumed to be unique.
        real_samples: reference samples; the scaler is fitted on them.
        target_size: desired total number of returned rows.

    Returns:
        The selected rows of ``gen_samples`` (original, unscaled values): the
        initial tournament subset followed by the greedily added rows. If
        ``target_size`` is not larger than the initial subset, no rows are
        added and only the initial subset is returned.
    """
    print("-----------------------------------fitness--------------------------------------------")
    # Fit on the real data only and apply the same transform to both frames.
    scaler = MinMaxScaler()
    scaler.fit(real_samples)
    real_scaled = pd.DataFrame(scaler.transform(real_samples.copy()), columns=real_samples.columns, index=real_samples.index)
    gen_scaled = pd.DataFrame(scaler.transform(gen_samples.copy()), columns=gen_samples.columns, index=gen_samples.index)
    print('\n')

    real_normalized, gen_normalized = normalize_rows(real_scaled, gen_scaled)
    real_values = real_normalized.values

    # --- Stage 1: tournament -------------------------------------------------
    fitness_scores = []
    for indices in tournament_candidates:
        batch = gen_normalized.iloc[indices]
        fitness_scores.append([compute_jsd_matrixwise(real_values, batch.values), indices])
    min_div, min_jsd_indices = min(fitness_scores, key=lambda x: x[0])

    # Candidates are row *positions*; translate them to index labels so the
    # same rows can be excluded below and looked up in gen_samples at the end
    # (resetting the index here would misalign the final lookup).
    initial_rows = gen_normalized.iloc[min_jsd_indices]
    init_labels = list(initial_rows.index)
    greedy_initial_subset = initial_rows.reset_index(drop=True)

    greedy_corr = greedy_initial_subset[["THICKNESS_TULUS [mm]", "CONTOUR_SPEED [mm/min]", "LASER_POWER [W]",
                                         "CONTOUR_GAS_PRESSURE [bar]", "CONTOUR_FOCAL [mm]"]].corr()
    plt.figure(figsize=(10, 8))
    sns.heatmap(greedy_corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1, center=0)
    plt.title('Matrice di Correlazione dati sintetici')
    plt.savefig("greedy_initial_subset_corr.png")
    plt.close()  # release the figure once it is saved

    print("min jsd:")
    print(min_div)
    init_len = len(greedy_initial_subset)
    print(f"Greedy Initial Subset initial len:{init_len}")

    # --- Stage 2: greedy extension ------------------------------------------
    # compute_jsd_matrixwise averages over all (real, generated) pairs, so for
    # the initial subset S and a candidate row c:
    #   JSD(real, S + [c]) = (|S| * JSD(real, S) + JSD(real, [c])) / (|S| + 1)
    # JSD(real, S) is min_div, so only the pairs involving c are computed.
    remaining = gen_normalized.drop(index=init_labels)
    greedy_scores = []
    for label, row in zip(remaining.index, remaining.values):
        candidate_jsd = compute_jsd_matrixwise(real_values, row[np.newaxis, :])
        greedy_scores.append(((init_len * min_div + candidate_jsd) / (init_len + 1), label))
    greedy_scores.sort(key=lambda item: item[0])

    # Fill up to target_size. Clamp at 0: a negative count in the slice would
    # drop candidates from the end instead of selecting none.
    n_to_add = max(0, int(target_size - init_len))
    selected_labels = [label for _, label in greedy_scores[:n_to_add]]

    final_subset = pd.concat([greedy_initial_subset, remaining.loc[selected_labels]],
                             axis=0, ignore_index=True)
    final_js_distance = compute_jsd_matrixwise(real_values, final_subset.values)
    print(f"final jsd: {final_js_distance}")

    # Return the same rows whose JSD was just reported, in original units.
    return gen_samples.loc[init_labels + selected_labels]



def tournament_selection(gen_samples,real_samples,B,target_size):
    """Draw ``B`` random tournament batches and let ``fitness`` pick the subset.

    Each batch is a list of ``len(real_samples)`` distinct row positions into
    ``gen_samples``, drawn without replacement from NumPy's global RNG.
    ``gen_samples`` therefore needs at least as many rows as ``real_samples``;
    otherwise ``np.random.choice`` raises ValueError.

    Returns:
        Whatever ``fitness`` returns for these batches.
    """
    batch_size = len(real_samples)
    population = len(gen_samples)
    batches = []
    for _ in range(B):
        draw = np.random.choice(population, batch_size, replace=False)
        batches.append(draw.tolist())
    return fitness(batches, gen_samples, real_samples, target_size)




# Script: keep only "Plasma" samples from the real and TVAE-generated data,
# select a generated subset close to the real data, then plot both.
real_data = pd.read_excel('./merged_files.xlsx')
gen_data = pd.read_excel('./synthetic_data_tvae(defects_2000_epochs).xlsx')
real_data = real_data[real_data["DEFECT_TYPE"] == 'Plasma'].copy().reset_index(drop=True)
gen_data = gen_data[gen_data["DEFECT_TYPE"] == 'Plasma'].copy().reset_index(drop=True)
# flatten_categorical_columns drops these columns (or QUALITY_CUT), so give
# the generated frame fixed values for them.
gen_data['CONTOUR_LASER_MODE'] = 'GP'
gen_data['TECHNOLOGY_GAS'] = 'N2'
gen_data['QUALITY_CUT'] = 'Bad'
gen_data['LASER_TYPE'] = 'YLS'
real_data = flatten_categorical_columns(real_data.copy(), categorical_features, only_quality_cut=False)
gen_data = flatten_categorical_columns(gen_data.copy(), categorical_features, only_quality_cut=False)
gen_data = gen_data.drop(['DEFECT_TYPE'], axis=1)
real_data = real_data.drop(["DEFECT_TYPE"], axis=1)

# Each frame is one-hot encoded separately, so a category present in only one
# of them produces a dummy column the other lacks. The scaler in fitness() is
# fitted on real_data and applied to gen_data, and it expects the same columns
# in the same order. Align both frames on the union; missing dummies become 0.
mismatched_columns = real_data.columns.symmetric_difference(gen_data.columns)
if len(mismatched_columns) > 0:
    print(f"columns present in only one frame (filled with 0): {list(mismatched_columns)}")
all_columns = real_data.columns.union(gen_data.columns, sort=False)
real_data = real_data.reindex(columns=all_columns, fill_value=0)
gen_data = gen_data.reindex(columns=all_columns, fill_value=0)

print(real_data.columns)
print(gen_data.columns)
print(real_data.shape)
print(gen_data.shape)
greedy_subset = tournament_selection(gen_data, real_data, 2000, 100)
plot_matrix_grid(greedy_subset, speed_is_dense=True, pressure_is_dense=True, title='greedy_subset')
plot_matrix_grid(real_data, speed_is_dense=True, pressure_is_dense=True, title='real_data')

# Columns compared in the heatmaps: continous_features without
# CONTOUR_NOZZLE_DISTANCE [mm].
corr_columns = ["THICKNESS_TULUS [mm]", "CONTOUR_SPEED [mm/min]", "LASER_POWER [W]",
                "CONTOUR_GAS_PRESSURE [bar]", "CONTOUR_FOCAL [mm]"]
greedy_corr = greedy_subset[corr_columns].corr()
real_corr = real_data[corr_columns].corr()
for corr_matrix, plot_title, out_file in (
    (greedy_corr, 'Matrice di Correlazione dati sintetici', "greedy_subset_corr.png"),
    (real_corr, 'Matrice di Correlazione dati originali', "real_corr.png"),
):
    plt.figure(figsize=(10, 8))
    sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1, center=0)
    plt.title(plot_title)
    plt.savefig(out_file)
    plt.close()  # release the figure once it is saved