import torch
import pandas as pd
import matplotlib
import ot
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler,StandardScaler,LabelEncoder
import math
import numpy as np
from plot_module import plot_matrix_grid
from sdv.metadata import Metadata
from sdv.sampling import Condition
from sdv.single_table import TVAESynthesizer
from sdv.evaluation.single_table import run_diagnostic, evaluate_quality
from sdv.evaluation.single_table import get_column_plot
import re
from scipy.stats import ks_2samp, chi2_contingency
from sklearn.metrics.pairwise import rbf_kernel
from scipy.stats import wasserstein_distance
from sklearn.metrics.pairwise import rbf_kernel
from scipy.spatial.distance import mahalanobis
from sklearn.decomposition import PCA


def sanitize_filename(filename):
    """Make *filename* safe to use as a file name.

    Every character that is not a word character (Unicode letter, digit or
    underscore), a hyphen or a dot is replaced by an underscore.
    """
    forbidden = re.compile(r'[^\w\-_\.]')
    return forbidden.sub('_', filename)

def frobenius_norm(A):
    """Return the Frobenius norm of the 2-D array-like ``A``.

    The Frobenius norm is the square root of the sum of the squared entries;
    this script uses it to measure how far apart two correlation matrices are.
    """
    fro_norm = np.linalg.norm(A, ord="fro")
    return fro_norm

def frobenius_max_distance(n):
    """Return an upper bound on the Frobenius distance between two n x n
    correlation matrices.

    Both matrices have ones on the diagonal, so only the ``n * (n - 1)``
    off-diagonal entries can differ, each by at most 2 (from +1 to -1).
    The bound is reached for n == 2; for n >= 3 no pair of valid correlation
    matrices actually attains it.
    """
    off_diagonal_entries = n * (n - 1)
    return 2 * np.sqrt(off_diagonal_entries)


# Load the synthetic dataset (dropping its "is_syn" marker column) and the real
# dataset (dropping four process columns that are not compared).
syn_data=pd.read_excel("synthetic_data_ctgan(defects_1000_epochs).xlsx").drop(columns=["is_syn"])
real_data=pd.read_excel("merged_files.xlsx").drop(columns=["CONTOUR_LASER_MODE","LASER_TYPE","TECHNOLOGY_GAS","QUALITY_CUT"])
# Stack real rows first, then synthetic rows, into one frame.
final_data=pd.concat([real_data,syn_data],axis=0).reset_index(drop=True)
print(f"len before drop: {len(final_data)}")
# Drop rows that are identical on every column except DEFECT_TYPE. With
# keep='first' and real rows stacked first, a real row wins over a synthetic
# duplicate of it.
final_data=final_data.copy().drop_duplicates(subset=[col for col in final_data.columns if col !='DEFECT_TYPE'],keep='first').reset_index(drop=True)
print(f"len after drop: {len(final_data)}")
#final_data.to_excel("final_data1(last).xlsx")

# Sanity output: column sets and per-class counts of both datasets.
print(real_data.columns)
print(syn_data.columns)
print(f'syn: {syn_data["DEFECT_TYPE"].value_counts()}')
print(f'real: {real_data["DEFECT_TYPE"].value_counts()}')

# Numeric features used by the analyses below; every other column of the
# merged frame is treated as categorical.
# NOTE(review): categorical_cols is taken from the UNION of both datasets'
# columns (via final_data), so the chi-square loop assumes real_data and
# syn_data share exactly the same columns — confirm.
numeric_cols=["THICKNESS_TULUS [mm]","CONTOUR_SPEED [mm/min]","LASER_POWER [W]","CONTOUR_GAS_PRESSURE [bar]","CONTOUR_NOZZLE_DISTANCE [mm]",
"CONTOUR_FOCAL [mm]"]
categorical_cols=[col for col in final_data.columns if col not in numeric_cols]
# SDV metadata inferred from the real data. In this file it is only referenced
# by the disabled evaluation code kept inside a string literal further down.
metadata=Metadata.detect_from_dataframe(real_data)

# 1. Two-sample Kolmogorov-Smirnov test for each numeric column (real vs synthetic).
print("\n=== Test di Kolmogorov-Smirnov ===")
for col in numeric_cols:
    real_values = real_data[col].dropna()
    syn_values = syn_data[col].dropna()
    # ks_2samp raises ValueError on an empty sample; report it and move on
    # instead of aborting the whole script.
    if real_values.empty or syn_values.empty:
        print(f"{col}: skipped (no non-missing values in one of the datasets)")
        continue
    stat, p_value = ks_2samp(real_values, syn_values)
    print(f"{col}: statistic={stat:.4f}, p-value={p_value:.4f}")

# 2. Chi-square test on the category frequencies of each categorical column.
print("\n=== Test del Chi-quadrato === ")
for col in categorical_cols:
    # categorical_cols is built from the union of both datasets' columns; a
    # column present in only one of them would otherwise raise KeyError.
    if col not in real_data.columns or col not in syn_data.columns:
        print(f"{col}: skipped (column missing from one of the datasets)")
        continue
    real_counts = real_data[col].value_counts()
    syn_counts = syn_data[col].value_counts()
    # An all-missing column gives a zero row sum, which chi2_contingency rejects
    # (zero expected frequencies).
    if real_counts.empty or syn_counts.empty:
        print(f"{col}: skipped (no non-missing values in one of the datasets)")
        continue
    # Align both frequency vectors on the same category set; a category seen in
    # only one dataset gets a zero count in the other.
    all_categories = real_counts.index.union(syn_counts.index)
    real_freq = real_counts.reindex(all_categories, fill_value=0)
    syn_freq = syn_counts.reindex(all_categories, fill_value=0)
    # 2 x k contingency table: rows = dataset (real / synthetic), columns = category.
    contingency_table = pd.DataFrame({'Real': real_freq, 'Synthetic': syn_freq})
    chi2_stat, p_value, _, _ = chi2_contingency(contingency_table.T)
    print(f"{col}: chi2_stat={chi2_stat:.4f}, p-value={p_value:.4f}")




print("\n=== PCA Analysis ===")
for defect in real_data["DEFECT_TYPE"].unique():
    real_subset = real_data.loc[real_data["DEFECT_TYPE"] == defect, numeric_cols]
    syn_subset = syn_data.loc[syn_data["DEFECT_TYPE"] == defect, numeric_cols]

    # PCA(n_components=2) needs at least 2 samples to fit, and
    # StandardScaler.transform rejects an empty array, so report and skip
    # classes that cannot be projected instead of aborting the whole script.
    if len(real_subset) < 2 or syn_subset.empty:
        print(f"PCA skipped for class {defect}: "
              f"{len(real_subset)} real / {len(syn_subset)} synthetic samples")
        continue

    # Standardize with statistics estimated on the real samples of this class
    # only, so the synthetic points are expressed on the real data's scale.
    scaler = StandardScaler()
    real_norm = scaler.fit_transform(real_subset)
    syn_norm = scaler.transform(syn_subset)

    # Fit PCA on the real samples only and project the synthetic samples onto
    # the same axes, so both point clouds share one coordinate system.
    pca = PCA(n_components=2)
    real_pca = pca.fit_transform(real_norm)
    syn_pca = pca.transform(syn_norm)

    # Fraction of the real data's variance captured by each of the 2 components.
    explained_variance = pca.explained_variance_ratio_
    print(f"Explained variance by first 2 PCs ({defect}): {explained_variance}")

    # 2-D scatter of real vs synthetic samples in the real-data PCA space.
    plt.figure(figsize=(10, 6))
    plt.scatter(real_pca[:, 0], real_pca[:, 1], label=f'Real_{defect}', alpha=0.5, s=20)
    plt.scatter(syn_pca[:, 0], syn_pca[:, 1], label=f'Synthetic_{defect}', alpha=0.5, s=20)
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')
    plt.title(f'PCA Projection (Real vs Synthetic) - {defect}')
    plt.legend()
    plt.grid(True)
    plt.savefig(f"./results/PCA_projection_{defect}.png")
    plt.close()

# Standardize both datasets with statistics estimated on the real data, so all
# numeric features contribute on a comparable scale to the transport cost.
scaler = StandardScaler()
scaler.fit(real_data[numeric_cols])
syn_norm = scaler.transform(syn_data[numeric_cols])
real_norm = scaler.transform(real_data[numeric_cols])
print("\n=== Test distanza di Wasserstein ===")
# Dense (n_syn x n_real) ground-cost matrix between every synthetic and every
# real sample. ot.dist uses the squared Euclidean metric by default (POT docs).
M = ot.dist(syn_norm, real_norm)
# Rescale the cost to [0, 1]. When every pair coincides (max == 0) the matrix
# is already all zeros, and dividing would produce 0/0 = NaN.
max_cost = M.max()
if max_cost > 0:
    M = M / max_cost
# Empty weight lists make POT use uniform weights on both samples. The result
# gets its own name so it no longer shadows the scipy.stats
# wasserstein_distance function imported at the top of the file.
emd_distance = ot.emd2([], [], M)
print("Distanza di Wasserstein:", emd_distance)

print("\n=== Norma di Frobenius tra Matrici di Correlazione ===")
# Pearson correlation matrix of the synthetic numeric features. It is kept as a
# labelled DataFrame for the heatmap and converted to a NumPy array below.
syn_corr_matrix=syn_data[numeric_cols].corr()
plt.figure(figsize=(10, 8))
sns.heatmap(syn_corr_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1, center=0)
plt.title('Correlation matrix synthetic data')
plt.savefig("./results/Matrice_di_Correlazione_Synthetic.png")
plt.close()
# Keep only the numeric features plus the three categorical columns that the
# rest of the script works with.
real_data=real_data[numeric_cols + ['NOZZLE_TYPE', 'MATERIAL_NAME_TULUS', 'DEFECT_TYPE']]
# Reuse the correlations computed above rather than recomputing them.
syn_corr_matrix=syn_corr_matrix.values
real_corr_matrix=real_data[numeric_cols].corr()
plt.figure(figsize=(10, 8))
sns.heatmap(real_corr_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1, center=0)
plt.title('Correlation matrix real data')
plt.savefig("./results/Matrice_di_Correlazione_Real.png")
plt.close()
real_corr_matrix=real_corr_matrix.values
# Frobenius norm of the element-wise difference; 0 means identical matrices.
# NOTE(review): pandas yields NaN correlations for a constant column, which
# would make frob NaN — confirm no numeric column is constant in either dataset.
frob=frobenius_norm(syn_corr_matrix-real_corr_matrix)


n = len(numeric_cols)
# Upper bound of the distance for n x n correlation matrices. It turns the
# distance into the 0-100 % "compatibility" score printed below.
max_distance = frobenius_max_distance(n)
print(f"La Frobenius distance è di: {frob}")
print("La minima distanza (norma di Frobenius) è: 0 (matrici uguali)")
print(f"La massima distanza (norma di Frobenius) per una matrice di dimensione {n}x{n} è: {max_distance} (matrici completamente opposte)")
print(f"Compatibilità dei dati sintetici: {round((1-(frob/max_distance))*100,2)} %")
# Disabled SDV evaluation (quality report, diagnostics, per-column plots): the
# code below is wrapped in a bare string literal and is never executed.
# NOTE(review): `metadata` was detected on real_data BEFORE the column subset
# above; confirm it still matches real_data before re-enabling this code.
"""
print("---------------------------------------------")
print("Evaluating quality trought basic validity check.. \n")
quality_report = evaluate_quality(real_data, syn_data, metadata,verbose=True)
print(quality_report)
print("-----------------------------------------------")
print("Running diagnostics (statistical similarity) .. \n")
diagnostic = run_diagnostic(real_data, syn_data, metadata,verbose=True)
print(diagnostic)
print("-----------------------------------------------")
print("Plotting data... \n")
for col in real_data.columns:
    fig = get_column_plot(
        real_data=real_data,
        synthetic_data=syn_data,
        metadata=metadata,
        column_name=col
    )
    fig.write_image(f'./evaluated_data_report/{sanitize_filename(col)}.png')
    fig.show()
"""
# Element-wise difference between the real and synthetic correlation matrices
# (both plain NumPy arrays at this point). Each entry is the difference of two
# correlations in [-1, 1], so it lies in [-2, 2]. The colour scale covers that
# full range instead of saturating at +/-1.
diff=real_corr_matrix-syn_corr_matrix
plt.figure(figsize=(10, 8))
# NumPy input carries no labels, so name the axes explicitly to match the
# real and synthetic correlation heatmaps.
sns.heatmap(diff, annot=True, cmap='coolwarm', vmin=-2, vmax=2, center=0,
            xticklabels=numeric_cols, yticklabels=numeric_cols)
plt.title('Matrice di Correlazione differenza (reale vs syn)')
plt.savefig("./results/Matrice_di_Correlazione_differenza.png")
plt.close()
# Render the project's plot_matrix_grid figures for both datasets.
plot_matrix_grid("syn_data_plot____",syn_data, cols_per_row=3, fig_width=18, row_height=5,speed_is_dense=True,pressure_is_dense=True,focal_is_dense=True)
plot_matrix_grid("real_data_plot____",real_data, cols_per_row=3, fig_width=18, row_height=5,speed_is_dense=True,pressure_is_dense=True,focal_is_dense=True)

#-------------------------MAHALANOBIS DISTANCE----------------------------------

print("\n=== Distanza di Mahalanobis ===")

# For every defect class, compute the Mahalanobis distance of each synthetic
# sample from the distribution (mean + covariance) of the REAL samples of the
# same class.
mahalanobis_distances = {defect: [] for defect in real_data["DEFECT_TYPE"].unique()}
for defect in real_data["DEFECT_TYPE"].unique():
    # Rows with missing numeric values are dropped because they would turn the
    # mean, covariance or distance into NaN.
    real_block = (real_data.loc[real_data["DEFECT_TYPE"] == defect, numeric_cols]
                  .dropna().to_numpy(dtype=float))
    syn_block = (syn_data.loc[syn_data["DEFECT_TYPE"] == defect, numeric_cols]
                 .dropna().to_numpy(dtype=float))

    # The sample covariance (ddof=1) needs at least 2 real rows. With no
    # synthetic rows there is nothing to measure.
    if len(real_block) < 2 or len(syn_block) == 0:
        continue

    mean_real = real_block.mean(axis=0)
    cov_real = np.cov(real_block, rowvar=False)
    # Use the pseudo-inverse rather than np.linalg.inv. A feature that is
    # constant within a class, or fewer rows than features, makes the covariance
    # singular; inv then raises LinAlgError or returns meaningless values.
    inv_cov_real = np.linalg.pinv(cov_real)

    # Vectorised sqrt((x - mu)^T S^+ (x - mu)) for all synthetic rows at once.
    # Clip tiny negative round-off before taking the square root.
    delta = syn_block - mean_real
    squared = np.einsum("ij,jk,ik->i", delta, inv_cov_real, delta)
    mahalanobis_distances[defect] = np.sqrt(np.clip(squared, 0.0, None)).tolist()

for defect in real_data["DEFECT_TYPE"].unique():
    distances = mahalanobis_distances[defect]
    if not distances:
        print(f"Mean Mahalanobis distance of class {defect}: n/a (not enough real or synthetic samples)")
        continue
    mean_distance=np.mean(distances)
    print(f"Mean Mahalanobis distance of class {defect}: {mean_distance}")