#from ctgan import CTGAN
from sdv.single_table import CTGANSynthesizer
import pandas as pd
from sdv.sampling import Condition
from sdv.metadata import Metadata
from sdv.constraints import create_custom_constraint
# Number of synthetic rows to request from the conditional sampler.
num_samples = 10

# Load the cleaned dataset and keep only defect-free rows. After this filter
# DEFECT_TYPE holds a single value, so it is dropped before training.
final_data = pd.read_excel('outlier_detection_output/final_data_cleaned.xlsx')
final_data = final_data[final_data['DEFECT_TYPE'] == 'No Defects'].drop(
    columns=['DEFECT_TYPE']  # `columns=` already selects axis 1; no `axis` needed
)
if final_data.empty:
    raise ValueError(
        "No rows with DEFECT_TYPE == 'No Defects' were found; nothing to train on."
    )

metadata = Metadata.detect_from_dataframe(final_data)

# Auto-detection may mark these columns as categorical; force them to
# numerical so the synthesizer models them as continuous values.
for numeric_column in ('THICKNESS_TULUS [mm]', 'CONTOUR_SPEED [mm/min]'):
    metadata.update_column(column_name=numeric_column, sdtype='numerical')
print(metadata.to_dict())

# Sample conditioned on the first non-null material present in the data.
# (Conditioning on a missing value would not be meaningful.)
materials = final_data['MATERIAL_NAME_TULUS'].dropna().unique()
if len(materials) == 0:
    raise ValueError("MATERIAL_NAME_TULUS has no non-null values to condition on.")
material_type = materials[0]

# NOTE(review): this per-material subset does not appear to be used by the
# rest of this script (training uses `final_data` and conditions on
# MATERIAL_NAME_TULUS instead) — confirm before removing.
final_data_subset_material = final_data[
    final_data['MATERIAL_NAME_TULUS'] == material_type
].drop(columns=['MATERIAL_NAME_TULUS'])

ctgan = CTGANSynthesizer(
    metadata,
    epochs=2000,
    batch_size=16,
    generator_dim=(32, 32),
    discriminator_dim=(32, 32),
    # CTGAN's discriminator groups samples in packs of `pac`, so batch_size
    # should stay a multiple of it (16 % 8 == 0).
    pac=8,
    generator_lr=1e-4,  # lower learning rate
    discriminator_lr=1e-4,
)
def nozzle_speed_constraint(column_names, data=None, *, nozzle_type='SMT 5.0',
                            max_speed=1800):
    """Check that rows using ``nozzle_type`` do not exceed ``max_speed``.

    Two calling conventions are supported:

    * ``nozzle_speed_constraint(column_names)`` returns a one-argument
      ``is_valid(table_data)`` callable. This is the original factory behaviour.
    * ``nozzle_speed_constraint(column_names, data)`` evaluates the rule on
      ``data`` directly. This matches the ``is_valid_fn(column_names, data)``
      signature used by SDV custom constraint classes.

    Args:
        column_names: Names of the constrained columns. They are accepted for
            SDV compatibility but currently unused: the rule always reads
            ``'NOZZLE_TYPE'`` and ``'CONTOUR_SPEED [mm/min]'``.
        data: Table to validate. When ``None``, a validator callable is
            returned instead.
        nozzle_type: Nozzle value the speed limit applies to.
        max_speed: Inclusive upper bound on ``'CONTOUR_SPEED [mm/min]'`` for
            rows whose nozzle is ``nozzle_type``.

    Returns:
        A boolean Series with one entry per row (True means valid) when
        ``data`` is given. Otherwise, a callable that produces that Series.
    """
    def is_valid(table_data):
        is_target_nozzle = table_data['NOZZLE_TYPE'] == nozzle_type
        # Rows with any other nozzle are always valid. Rows with the target
        # nozzle must respect the speed limit. A NaN speed compares False, so
        # such a row counts as invalid.
        within_limit = table_data['CONTOUR_SPEED [mm/min]'] <= max_speed
        return ~is_target_nozzle | within_limit

    if data is None:
        return is_valid
    return is_valid(data)

# SDV 1.x custom constraints are declared as a class, registered on the
# synthesizer under a name, and then referenced by that name in
# add_constraints. The class calls is_valid_fn(column_names, data, **kwargs).
# NOTE(review): `create_custom_constraint` (imported at the top of the file)
# appears to be the SDV 0.x name. In SDV 1.x it is
# `create_custom_constraint_class`. Confirm against the installed version.
from sdv.constraints import create_custom_constraint_class


def _nozzle_speed_is_valid(column_names, data, **kwargs):
    """Adapt the `nozzle_speed_constraint` factory to SDV's is_valid_fn signature."""
    return nozzle_speed_constraint(column_names, **kwargs)(data)


NozzleSpeedConstraint = create_custom_constraint_class(
    is_valid_fn=_nozzle_speed_is_valid,
)
ctgan.add_custom_constraint_class(NozzleSpeedConstraint, 'NozzleSpeedConstraint')

constraint = {
    'constraint_class': 'NozzleSpeedConstraint',
    'constraint_parameters': {
        'column_names': ['NOZZLE_TYPE', 'CONTOUR_SPEED [mm/min]'],
    },
}
ctgan.add_constraints(constraints=[constraint])
ctgan.fit(final_data)

print("CTGAN condition: 'MATERIAL_NAME_TULUS' with value:", material_type)
cond = Condition(
    column_values={'MATERIAL_NAME_TULUS': material_type},
    num_rows=num_samples,
)
conditional_samples = ctgan.sample_from_conditions(conditions=[cond])
print(conditional_samples)