import math
import os

import matplotlib

# The backend must be selected before pyplot (also pulled in by seaborn) is imported.
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sdv.metadata import Metadata
from sdv.sampling import Condition
from sdv.single_table import TVAESynthesizer,GaussianCopulaSynthesizer,CTGANSynthesizer,CopulaGANSynthesizer
from sklearn.preprocessing import MinMaxScaler,StandardScaler,LabelEncoder

from plot_module import plot_matrix_grid


# Directory that receives every figure and Excel file written by this script.
# It is created up front because plt.savefig and DataFrame.to_excel do not
# create missing parent directories and would otherwise fail mid-run.
output_dir = 'results_ctgan'
os.makedirs(output_dir, exist_ok=True)

def flatten_categorical_columns(df,columns,only_quality_cut=False):
    """One-hot encode categorical columns and keep a single target label.

    NOZZLE_TYPE is split into a numeric NOZZLE_SIZE (first run of digits and
    dots) and an alphabetic NOZZLE_BASE_TYPE (first run of uppercase letters).
    This way, within one base type, the sizes keep a natural numeric ordering.
    The original NOZZLE_TYPE column is dropped.

    Args:
        df: input frame. It must contain NOZZLE_TYPE, QUALITY_CUT, DEFECT_TYPE,
            TECHNOLOGY_GAS and CONTOUR_LASER_MODE. It is not modified.
        columns: column names passed to ``pd.get_dummies`` (may include the
            derived NOZZLE_BASE_TYPE). TECHNOLOGY_GAS and CONTOUR_LASER_MODE
            are dropped unless they appear here.
        only_quality_cut: if True, keep QUALITY_CUT as the label
            ("Good" -> 1, "Bad" -> 0) and drop DEFECT_TYPE. Otherwise keep
            DEFECT_TYPE and drop QUALITY_CUT.

    Returns:
        A new DataFrame with the following properties:
        - boolean dummy columns are cast to int;
        - rows with duplicate feature values (labels ignored) are removed,
          keeping the first occurrence;
        - the index is a fresh RangeIndex.

    Raises:
        KeyError: if a referenced column is missing from ``df``.
    """
    # Work on a copy so the caller's frame does not gain the derived columns.
    df = df.copy()
    df["NOZZLE_SIZE"] = df["NOZZLE_TYPE"].str.extract(r'([\d\.]+)', expand=False).astype(float)
    df["NOZZLE_BASE_TYPE"] = df["NOZZLE_TYPE"].str.extract(r'([A-Z]+)', expand=False)
    df_encoded = pd.get_dummies(df.drop(columns="NOZZLE_TYPE"), columns=columns)
    bool_cols = df_encoded.select_dtypes(bool).columns
    df_encoded[bool_cols] = df_encoded[bool_cols].astype(int)

    if only_quality_cut:
        df_encoded["QUALITY_CUT"] = df_encoded["QUALITY_CUT"].replace({"Good": 1, "Bad": 0})
        df_encoded = df_encoded.drop(columns="DEFECT_TYPE")
    else:
        df_encoded = df_encoded.drop(columns="QUALITY_CUT")

    # Duplicates are judged on the features only: rows with identical
    # parameters but different labels collapse to the first one.
    feature_cols = [col for col in df_encoded.columns if col not in ('DEFECT_TYPE', 'QUALITY_CUT')]
    df_encoded = df_encoded.drop_duplicates(subset=feature_cols, keep='first').reset_index(drop=True)

    # Count good/bad samples with whichever label column survived above
    # (DEFECT_TYPE no longer exists when only_quality_cut is True).
    if only_quality_cut:
        is_good = df_encoded["QUALITY_CUT"] == 1
    else:
        is_good = df_encoded["DEFECT_TYPE"] == "No Defects"
    print("example: \n")
    print(f"total samples in df_encoded: {len(df_encoded)}")
    print(f'total samples good in df_encoded: {int(is_good.sum())}')
    print(f'total samples bad in df_encoded: {int((~is_good).sum())}')

    # Raw categorical columns are dropped unless the caller asked to encode them.
    unencoded = [col for col in ("TECHNOLOGY_GAS", "CONTOUR_LASER_MODE") if col not in columns]
    return df_encoded.drop(columns=unencoded)





# Load the merged measurement table and discard rows whose thickness is 4.0 mm.
merged_df2 = pd.read_excel("./merged_files.xlsx")
merged_df2 = merged_df2.loc[merged_df2["THICKNESS_TULUS [mm]"] != 4.0].copy()

# Process parameters, categorical descriptors and the DEFECT_TYPE label that
# the synthesizers are trained on.
selected_columns = [
    "THICKNESS_TULUS [mm]", "CONTOUR_SPEED [mm/min]", "LASER_POWER [W]",
    "CONTOUR_GAS_PRESSURE [bar]", "CONTOUR_NOZZLE_DISTANCE [mm]", "CONTOUR_FOCAL [mm]",
    "NOZZLE_TYPE", "MATERIAL_NAME_TULUS", "DEFECT_TYPE",
]
data = merged_df2[selected_columns]

class_counts = data["DEFECT_TYPE"].value_counts()
print(class_counts)
defects_list = data["DEFECT_TYPE"].unique()
discrete_columns = ["NOZZLE_TYPE", "MATERIAL_NAME_TULUS"]

# Per-class CTGAN settings: (batch size, epochs, population factor).
_class_settings = {
    'Plasma': (6, 2000, 0.5),
    'Burr': (20, 1000, 1.0),
    'Cutting loss': (6, 2000, 0.5),
    'Cutting torn': (10, 2000, 0.65),
    'No Defects': (18, 1000, 1.0),
}
defects_params = {
    label: {
        'batch': batch,
        'n_samples': int(class_counts.get(label, 0)),  # real rows of this class
        'epochs': epochs,
        'l2scale': 1e-5,
        'population_factor': factor,
    }
    for label, (batch, epochs, factor) in _class_settings.items()
}

# Multipliers for snapping synthetic values to the measurement grid; the
# augmentation loop applies round(value * factor) / factor per column.
rounding_factors = {
    "THICKNESS_TULUS [mm]": 1,
    "CONTOUR_SPEED [mm/min]": 1 / 50,
    "LASER_POWER [W]": 1 / 1000,
    "CONTOUR_GAS_PRESSURE [bar]": 2,
    "CONTOUR_NOZZLE_DISTANCE [mm]": 10,
    "CONTOUR_FOCAL [mm]": 2,
}

n_generations = 1000
synthetic_dt_ctgan = pd.DataFrame(columns=data.columns)
final_data2 = pd.DataFrame(columns=data.columns)
print(f"synthetic_dt_ctgan start len: {len(synthetic_dt_ctgan)}")

n_epochs = 2000
batch_size = 16
metadata = Metadata.detect_from_dataframe(data)
print(metadata)

# Augment every class listed in defects_params. For each DEFECT_TYPE:
#   1. train one CTGAN on that class's rows and sample from it;
#   2. snap the samples to the measurement grid;
#   3. drop values outside a tolerance band around the observed range;
#   4. collect the surviving rows in synthetic_dt_ctgan.

# Continuous process parameters shown in the correlation heat-maps.
numeric_cols = ["THICKNESS_TULUS [mm]", "CONTOUR_SPEED [mm/min]", "LASER_POWER [W]",
                "CONTOUR_GAS_PRESSURE [bar]", "CONTOUR_NOZZLE_DISTANCE [mm]", "CONTOUR_FOCAL [mm]"]


def _save_corr_heatmap(frame, title, path):
    """Save a heat-map of the correlation matrix of ``frame[numeric_cols]`` to *path*.

    The figure is closed after saving, so per-defect figures do not pile up
    until the final ``plt.show()``.
    """
    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(frame[numeric_cols].corr(), annot=True, cmap='coolwarm', vmin=-1, vmax=1, center=0)
    plt.title(title)
    plt.savefig(path)
    plt.close(fig)


def _tolerance_bounds(col, min_vals, max_vals):
    """Return the inclusive (low, high) band a synthetic value of *col* must lie in.

    - Speed and power are widened on a log10 scale: 5 % of the log-range below,
      10 % above.
    - Other columns are widened linearly: 5 % of the range below, 10 % above.
    - A degenerate range (min == max) is widened by 10 % of its magnitude on
      each side.
    """
    if col in ("CONTOUR_SPEED [mm/min]", "LASER_POWER [W]"):
        # NOTE(review): log10 assumes strictly positive speed/power values.
        log_min = np.log10(min_vals)
        log_max = np.log10(max_vals)
        if log_min != log_max:
            log_range = log_max - log_min
            return 10 ** (log_min - 0.05 * log_range), 10 ** (log_max + 0.1 * log_range)
    elif min_vals != max_vals:
        span = max_vals - min_vals
        return min_vals - 0.05 * span, max_vals + 0.1 * span
    # Degenerate range. abs() keeps low <= high for a negative constant
    # (e.g. a negative focal position); without it the band would be empty
    # and every synthetic row would be discarded.
    return min_vals - 0.1 * abs(min_vals), max_vals + 0.1 * abs(max_vals)


data1 = data.copy()
print(len(data1))
synthetic_parts = []
for defect, params in defects_params.items():
    # NOTE(review): pac equals batch_size, so each discriminator pack spans the
    # whole batch; CTGAN expects batch_size to be divisible by pac.
    synthetizer = CTGANSynthesizer(metadata=metadata, epochs=params['epochs'], batch_size=params['batch'],
                                   verbose=True, pac=params['batch'])
    print(defect)
    real_subset = data1[data1["DEFECT_TYPE"] == defect]
    _save_corr_heatmap(real_subset, 'Matrice di Correlazione dati originali',
                       f'{output_dir}/Matrice_di_Correlazione_dati_originali_{defect}')
    print("original data plot:")
    plot_matrix_grid(data[data["DEFECT_TYPE"] == defect], cols_per_row=3, fig_width=18, row_height=5,
                     speed_is_dense=True, pressure_is_dense=True, focal_is_dense=True,
                     title=f"{output_dir}/original_data_plot_{defect}")

    synthetizer.fit(real_subset)
    synthetic_data_ctgan = synthetizer.sample(int(params['population_factor'] * n_generations))
    print(f"generated len synthetic_ctgan: {len(synthetic_data_ctgan)}")
    _save_corr_heatmap(synthetic_data_ctgan, 'Matrice di Correlazione CTGAN',
                       f"{output_dir}/Matrice_di_Correlazione_CTGAN_{defect}")
    print("Synthetic_data_ctgan plot: \n")
    plot_matrix_grid(synthetic_data_ctgan, cols_per_row=3, fig_width=18, row_height=5,
                     speed_is_dense=True, pressure_is_dense=True, focal_is_dense=True,
                     title=f"{output_dir}/Synthetic_data_ctgan_plot_{defect}")
    print(data1.columns)

    for col in synthetic_data_ctgan.columns:
        if col in discrete_columns or col == "DEFECT_TYPE":
            continue
        # NOTE(review): bounds come from all classes in data1, not only the
        # current defect — confirm this is intended.
        min_vals = data1[col].min()
        print(min_vals)
        max_vals = data1[col].max()
        low, high = _tolerance_bounds(col, min_vals, max_vals)
        factor = rounding_factors[col]
        synthetic_data_ctgan[col] = np.round(synthetic_data_ctgan[col] * factor) / factor
        synthetic_data_ctgan = synthetic_data_ctgan[synthetic_data_ctgan[col].between(low, high)].copy()

    synthetic_parts.append(synthetic_data_ctgan)
    _save_corr_heatmap(synthetic_data_ctgan, 'Matrice di Correlazione CTGAN post pulizia',
                       f"{output_dir}/Matrice_di_Correlazione_CTGAN_post_pulizia_{defect}")
    print(f"synthetic_data_CTGAN cleaned len: {len(synthetic_data_ctgan)}")
    print("Synthetic_data_CTGAN cleaned plot: \n")
    if synthetic_data_ctgan.empty:
        # Nothing survived the range filter; iloc[0] would raise IndexError.
        print(f"no synthetic rows left for {defect} after cleaning")
        continue
    print(synthetic_data_ctgan.iloc[0])
    # Include the defect in the name so each class gets its own file.
    plot_matrix_grid(synthetic_data_ctgan, cols_per_row=3, fig_width=18, row_height=5,
                     speed_is_dense=True, pressure_is_dense=True, focal_is_dense=True,
                     title=f"{output_dir}/Synthetic_data_CTGAN_cleaned_plot_{defect}")

# Concatenate once, then drop rows repeating a parameter set (labels ignored),
# keeping the first occurrence.
if synthetic_parts:
    synthetic_dt_ctgan = pd.concat(synthetic_parts, axis=0, ignore_index=True)
else:
    synthetic_dt_ctgan = pd.DataFrame(columns=data.columns)
synthetic_dt_ctgan = synthetic_dt_ctgan.drop_duplicates(
    subset=[col for col in synthetic_dt_ctgan.columns if col not in ('DEFECT_TYPE', 'QUALITY_CUT')],
    keep='first',
).reset_index(drop=True)





# Merge real and synthetic rows into the final dataset.
# Real rows come first: when a synthetic row repeats the feature values of a
# real measurement, drop_duplicates(keep='first') keeps the real row and its
# recorded DEFECT_TYPE rather than the generated copy.
final_data2 = pd.concat([data.reset_index(drop=True), synthetic_dt_ctgan.reset_index(drop=True)],
                        axis=0, ignore_index=True)
print(f"final_data2 columns:\n {final_data2.columns}")

final_data2 = final_data2.drop_duplicates(
    subset=[col for col in final_data2.columns if col != 'DEFECT_TYPE'],
    keep='first',
).reset_index(drop=True)
print("final_data_ctgan plot:")
plot_matrix_grid(final_data2, cols_per_row=3, fig_width=18, row_height=5,
                 speed_is_dense=True, pressure_is_dense=True, focal_is_dense=True,
                 title=f"{output_dir}/final_data_ctgan plot")
print(f'final_data_ctgan{final_data2["DEFECT_TYPE"].value_counts()}')

numeric_cols = ["THICKNESS_TULUS [mm]", "CONTOUR_SPEED [mm/min]", "LASER_POWER [W]",
                "CONTOUR_GAS_PRESSURE [bar]", "CONTOUR_NOZZLE_DISTANCE [mm]", "CONTOUR_FOCAL [mm]"]
corr_matrix_final_ctgan = final_data2[numeric_cols].corr()
plt.figure(figsize=(10, 8))
sns.heatmap(corr_matrix_final_ctgan, annot=True, cmap='coolwarm', vmin=-1, vmax=1, center=0)
plt.title('Matrice di Correlazione CTGAN finale')
plt.savefig(f"{output_dir}/Matrice_di_Correlazione_CTGAN_finale")
plt.show()
print(final_data2.columns)

# Idempotent safety net: synthetic rows must not repeat a parameter set.
synthetic_dt_ctgan = synthetic_dt_ctgan.drop_duplicates(
    subset=[col for col in synthetic_dt_ctgan.columns if col not in ('DEFECT_TYPE', 'QUALITY_CUT')],
    keep='first',
).reset_index(drop=True)
# NOTE(review): n_epochs is the global 2000, but the synthesizers were trained
# with per-class epochs from defects_params (1000 or 2000). The number in these
# file names may not reflect the actual training.
synthetic_dt_ctgan.to_excel(f"{output_dir}/synthetic_data_ctgan(defects_{n_epochs}_epochs).xlsx", index=False)
final_data2.to_excel(f"{output_dir}/final_data2_(defects_{n_epochs}_epochs).xlsx", index=False)

