import torch
import pandas as pd
import matplotlib
import ot
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler,StandardScaler,LabelEncoder
import math
import numpy as np
from plot_module import plot_matrix_grid
from sdv.metadata import Metadata
from sdv.sampling import Condition
from sdv.single_table import TVAESynthesizer
from sdv.evaluation.single_table import run_diagnostic, evaluate_quality
from sdv.evaluation.single_table import get_column_plot
import re
from scipy.stats import ks_2samp, chi2_contingency
from sklearn.metrics.pairwise import rbf_kernel
from scipy.stats import wasserstein_distance
from sklearn.metrics.pairwise import rbf_kernel
from scipy.spatial.distance import mahalanobis
from sklearn.decomposition import PCA
import prince
from sklearn.ensemble import IsolationForest

def flatten_categorical_columns(df,columns,only_quality_cut=False):
    """One-hot encode ``columns`` after splitting NOZZLE_TYPE, then de-duplicate rows.

    NOZZLE_TYPE is split into NOZZLE_BASE_TYPE (the first run of capital
    letters) and NOZZLE_SIZE (the first run of digits/dots, as float), so that
    sizes can be treated as ordered within each base type.

    Args:
        df: Input frame. Must contain NOZZLE_TYPE, QUALITY_CUT and DEFECT_TYPE.
            It is not modified.
        columns: Column names passed to ``pd.get_dummies`` for one-hot encoding.
        only_quality_cut: If True, keep QUALITY_CUT as the label (with "Good"
            replaced by 1 and "Bad" by 0) and drop DEFECT_TYPE. Otherwise keep
            DEFECT_TYPE and drop QUALITY_CUT.

    Returns:
        A new encoded DataFrame with rows de-duplicated on every non-label
        column and a fresh 0..n-1 index. TECHNOLOGY_GAS and CONTOUR_LASER_MODE
        are removed unless they were listed in ``columns``.
    """
    # Work on a copy so the caller's frame does not silently gain helper columns.
    df = df.copy()

    # expand=False yields a Series (one capture group) instead of a 1-column frame.
    df["NOZZLE_SIZE"] = df["NOZZLE_TYPE"].str.extract(r'([\d\.]+)', expand=False).astype(float)
    df["NOZZLE_BASE_TYPE"] = df["NOZZLE_TYPE"].str.extract(r'([A-Z]+)', expand=False)

    df_encoded = pd.get_dummies(df.drop("NOZZLE_TYPE", axis=1), columns=columns)

    # Convert boolean columns (e.g. the dummies) to 0/1 integers.
    bool_cols = df_encoded.select_dtypes(bool).columns
    if len(bool_cols):
        df_encoded[bool_cols] = df_encoded[bool_cols].astype(int)

    if only_quality_cut:
        df_encoded["QUALITY_CUT"] = df_encoded["QUALITY_CUT"].replace({"Good": 1, "Bad": 0})
        df_encoded = df_encoded.drop("DEFECT_TYPE", axis=1)
        label_col, good_label = "QUALITY_CUT", 1
    else:
        df_encoded = df_encoded.drop("QUALITY_CUT", axis=1)
        label_col, good_label = "DEFECT_TYPE", "No Defects"

    # Rows are duplicates when all non-label columns match; keep the first one.
    feature_cols = [col for col in df_encoded.columns if col not in ('DEFECT_TYPE', 'QUALITY_CUT')]
    df_encoded = df_encoded.drop_duplicates(subset=feature_cols, keep='first').reset_index(drop=True)

    # Summarise on whichever label column survived; DEFECT_TYPE no longer
    # exists when only_quality_cut is True, so it cannot be indexed there.
    is_good = df_encoded[label_col] == good_label
    print("example: \n")
    print(f"total samples in df_encoded: {len(df_encoded)}")
    print(f'total samples good in df_encoded: {int(is_good.sum())}')
    print(f'total samples bad in df_encoded: {int((~is_good).sum())}')

    # Drop these raw columns unless they were one-hot encoded; errors="ignore"
    # tolerates inputs from which they were already removed.
    for optional_col in ("TECHNOLOGY_GAS", "CONTOUR_LASER_MODE"):
        if optional_col not in columns:
            df_encoded = df_encoded.drop(columns=[optional_col], errors="ignore")
    return df_encoded

def sanitize_filename(filename):
    """Return ``filename`` with every character that is not a word character,
    hyphen, underscore or dot replaced by an underscore."""
    invalid_chars = re.compile(r'[^\w\-_\.]')
    return invalid_chars.sub('_', filename)

def frobenius_norm(A):
    """Return the Frobenius norm of the matrix ``A``."""
    fro = np.linalg.norm(A, ord='fro')
    return fro

def frobenius_max_distance(n):
    """Return the maximum Frobenius distance between two n x n correlation matrices.

    The diagonals coincide and each of the n * (n - 1) off-diagonal entries can
    differ by at most 2, which gives 2 * sqrt(n * (n - 1)).
    """
    off_diagonal_count = n * (n - 1)
    return 2 * np.sqrt(off_diagonal_count)


# Load the synthetic table (without its "is_syn" column) and the real table
# (without the columns excluded from this comparison).
syn_data = pd.read_excel("synthetic_data_ctgan(defects_1000_epochs).xlsx")
syn_data = syn_data.drop(columns=["is_syn"])
real_data = pd.read_excel("merged_files.xlsx")
real_data = real_data.drop(
    columns=["CONTOUR_LASER_MODE", "LASER_TYPE", "TECHNOLOGY_GAS", "QUALITY_CUT"]
)

# Stack real rows first, then synthetic rows, and drop rows whose features
# (every column except DEFECT_TYPE) repeat an earlier row.
final_data = pd.concat([real_data, syn_data], ignore_index=True)
print(f"len before drop: {len(final_data)}")
dedup_cols = [name for name in final_data.columns if name != 'DEFECT_TYPE']
final_data = final_data.drop_duplicates(subset=dedup_cols, keep='first')
final_data = final_data.reset_index(drop=True)
print(f"len after drop: {len(final_data)}")
# Optional export, currently disabled:
#final_data.to_excel("final_data1(last).xlsx")

print(real_data.columns)
print(syn_data.columns)
print(f'syn: {syn_data["DEFECT_TYPE"].value_counts()}')
print(f'real: {real_data["DEFECT_TYPE"].value_counts()}')

# Columns treated as numeric; every other column is treated as categorical.
numeric_cols = [
    "THICKNESS_TULUS [mm]",
    "CONTOUR_SPEED [mm/min]",
    "LASER_POWER [W]",
    "CONTOUR_GAS_PRESSURE [bar]",
    "CONTOUR_NOZZLE_DISTANCE [mm]",
    "CONTOUR_FOCAL [mm]",
]
categorical_cols = [name for name in final_data.columns if name not in numeric_cols]
metadata = Metadata.detect_from_dataframe(real_data)


print("\n=== PCA Analysis ===")
# For each defect class, standardise the numeric features with statistics of
# the REAL samples, fit a 2-component PCA on the real samples only, project the
# synthetic samples into the same space, and save a scatter plot of both.
for defect in real_data["DEFECT_TYPE"].dropna().unique():
    real_subset = real_data.loc[real_data["DEFECT_TYPE"] == defect, numeric_cols]
    syn_subset = syn_data.loc[syn_data["DEFECT_TYPE"] == defect, numeric_cols]

    # A 2-component PCA needs at least two real samples, and
    # StandardScaler.transform rejects an empty synthetic subset; skip such
    # classes instead of aborting the whole run.
    if len(real_subset) < 2 or syn_subset.empty:
        print(f"Skipping PCA for {defect}: real={len(real_subset)}, synthetic={len(syn_subset)} samples")
        continue

    scaler = StandardScaler()
    scaler.fit(real_subset)
    syn_norm = scaler.transform(syn_subset)
    real_norm = scaler.transform(real_subset)

    # Fit PCA on the real data only and reuse that fit for the synthetic data
    # so both projections share the same axes.
    pca = PCA(n_components=2)
    real_pca = pca.fit_transform(real_norm)
    syn_pca = pca.transform(syn_norm)

    # Explained variance ratio of the two retained components
    explained_variance = pca.explained_variance_ratio_
    print(f"Explained variance by first 2 PCs: {explained_variance}")

    # 2D PCA plot
    plt.figure(figsize=(10, 6))
    plt.scatter(real_pca[:, 0], real_pca[:, 1], label=f'Real_{defect}', alpha=0.5, s=20)
    plt.scatter(syn_pca[:, 0], syn_pca[:, 1], label=f'Synthetic_{defect}', alpha=0.5, s=20)
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')
    plt.title('PCA Projection (Real vs Synthetic)')
    plt.legend()
    plt.grid(True)
    # NOTE(review): the defect label is interpolated verbatim into the path;
    # confirm labels never contain path separators.
    plt.savefig(f"./results_dim_red/PCA_projection_{defect}.png")
    plt.close()

# For each defect class, fit a 2-component FAMD on the real samples (numeric
# columns standardised with the real samples' statistics), project the
# synthetic samples with the same fit, and save a scatter plot of both.
for defect in real_data["DEFECT_TYPE"].dropna().unique():
    print(f"-------------------{defect}---------------------------------")
    real_norm = real_data[real_data["DEFECT_TYPE"] == defect].drop(columns=["DEFECT_TYPE"]).copy()
    syn_norm = syn_data[syn_data["DEFECT_TYPE"] == defect].drop(columns=["DEFECT_TYPE"]).copy()

    # StandardScaler.transform rejects an empty synthetic subset, and a single
    # real sample cannot support a 2-component projection; skip such classes
    # instead of aborting the whole run.
    if len(real_norm) < 2 or syn_norm.empty:
        print(f"Skipping FAMD for {defect}: real={len(real_norm)}, synthetic={len(syn_norm)} samples")
        continue

    scaler = StandardScaler()
    scaler.fit(real_norm[numeric_cols])
    real_norm[numeric_cols] = scaler.transform(real_norm[numeric_cols])
    syn_norm[numeric_cols] = scaler.transform(syn_norm[numeric_cols])

    famd = prince.FAMD(n_components=2, random_state=42)
    famd = famd.fit(real_norm)
    real_data_famd = famd.transform(real_norm)
    syn_data_famd = famd.transform(syn_norm)
    explained_variance = famd.eigenvalues_summary
    print(f"Explained variance by first 2 FAMDs: {explained_variance}")

    # Plot the data projected onto the first two components
    plt.figure(figsize=(10, 8))

    # Real data
    plt.scatter(real_data_famd[0], real_data_famd[1], label="Real Data", alpha=0.6, color='blue')

    # Synthetic data
    plt.scatter(syn_data_famd[0], syn_data_famd[1], label="Synthetic Data", alpha=0.6, color='red')

    # Labels and title
    plt.xlabel('FAMD Component 1')
    plt.ylabel('FAMD Component 2')
    plt.title('Real vs Synthetic Data in 2D (FAMD)')
    plt.legend()
    plt.grid(True)
    plt.savefig(f"./results_dim_red/FAMD_projection_{defect}.png")
    # Close the figure so figures do not accumulate across iterations.
    plt.close()

