import pandas as pd


def reverse_get_dummies(df_encoded, original_columns_prefixes, keep_old_columns=False):
    """Rebuild categorical columns from their one-hot indicator columns.

    For each prefix, every column whose label starts with ``f"{prefix}_"`` is
    treated as one indicator of that categorical variable. The restored value
    for a row is the label text after ``f"{prefix}_"`` of the indicator with
    the largest value in that row (the first such column on ties, as
    ``DataFrame.idxmax`` does). Rows where no indicator is set (all 0/False)
    are restored as NaN instead of being silently assigned the first category.

    Args:
        df_encoded: Frame containing the one-hot indicator columns. Not modified.
        original_columns_prefixes: Iterable of prefixes (original column names)
            to restore. Prefixes with no matching columns are skipped.
        keep_old_columns: When True, keep the indicator columns alongside the
            restored column; otherwise drop them.

    Returns:
        A copy of ``df_encoded`` with one restored column per matched prefix,
        appended at the end.
    """
    df_restored = df_encoded.copy()
    for prefix in original_columns_prefixes:
        marker = f"{prefix}_"
        # str() so non-string column labels do not raise on startswith.
        # NOTE(review): a prefix that is also the start of another column family
        # (e.g. "LASER" vs "LASER_TYPE_*") captures those columns too — confirm
        # the prefixes passed in are unambiguous.
        cols = [col for col in df_encoded.columns if str(col).startswith(marker)]
        if not cols:
            continue
        dummies = df_encoded[cols]
        # Explicit label -> category map; avoids the .str accessor, which fails
        # when idxmax yields a non-string dtype (e.g. on an empty frame).
        categories = {col: str(col)[len(marker):] for col in cols}
        restored = dummies.idxmax(axis=1).map(categories)
        # idxmax returns the first column for an all-zero row; mark those rows
        # as missing rather than inventing a category.
        has_active = dummies.any(axis=1).to_numpy()
        df_restored[prefix] = restored.where(has_active)
        if not keep_old_columns:
            df_restored = df_restored.drop(columns=cols)
    return df_restored


def flatten_categorical_columns(df,columns,only_quality_cut=False,drop_duplicates=False):
    """One-hot encode categorical columns and prepare the target/feature columns.

    Steps, in order:

    1. If ``NOZZLE_TYPE`` is present and ``NOZZLE_BASE_TYPE`` is either
       requested in ``columns`` or already present in ``df``, split
       ``NOZZLE_TYPE`` into ``NOZZLE_SIZE`` (first run of digits/dots, as float)
       and ``NOZZLE_BASE_TYPE`` (first run of uppercase letters), then drop
       ``NOZZLE_TYPE``.
    2. ``pd.get_dummies`` on ``columns``; every boolean column is cast to int.
    3. If ``only_quality_cut``: map ``QUALITY_CUT`` "Good"/"Bad" to 1/0 and
       drop ``DEFECT_TYPE``. Otherwise drop ``QUALITY_CUT`` if present
       (printing a warning if it is not).
    4. If ``drop_duplicates``: drop rows that duplicate an earlier row on every
       column except ``DEFECT_TYPE``/``QUALITY_CUT``, then reset the index.
    5. Print sample counts (rows whose ``DEFECT_TYPE`` is "No Defects" are
       reported as good, all others as bad).
    6. Drop ``TECHNOLOGY_GAS``, ``CONTOUR_LASER_MODE`` and ``LASER_TYPE``
       unless they were listed in ``columns`` (printing a warning for any
       that is missing).

    Args:
        df: Input frame. Not modified.
        columns: Column names to one-hot encode (passed to ``pd.get_dummies``).
        only_quality_cut: Keep ``QUALITY_CUT`` as a 0/1 target instead of
            ``DEFECT_TYPE``.
        drop_duplicates: Remove duplicate feature rows.

    Returns:
        A new encoded DataFrame.

    Raises:
        KeyError: If ``only_quality_cut`` is True and ``QUALITY_CUT`` or
            ``DEFECT_TYPE`` is missing.
    """
    requested = set(columns) if columns is not None else set()

    # Split NOZZLE_TYPE into (NOZZLE_BASE_TYPE, NOZZLE_SIZE) so that, within
    # each base type, the sizes carry a numeric ordering.
    # The split needs NOZZLE_TYPE itself, so require it. The split fires when
    # the base type is wanted: either requested for encoding or already present.
    # NOTE(review): confirm this trigger matches what callers expect.
    wants_base_type = "NOZZLE_BASE_TYPE" in df.columns or "NOZZLE_BASE_TYPE" in requested
    if "NOZZLE_TYPE" in df.columns and wants_base_type:
        nozzle = df["NOZZLE_TYPE"].str
        # assign() returns a new frame, so the caller's df is left untouched.
        df = df.assign(
            NOZZLE_SIZE=nozzle.extract(r'([\d\.]+)', expand=False).astype(float),
            NOZZLE_BASE_TYPE=nozzle.extract(r'([A-Z]+)', expand=False),
        ).drop(columns="NOZZLE_TYPE")

    # One-hot encoding.
    df_encoded = pd.get_dummies(df, columns=columns)
    # Cast every boolean column (including dummy indicators, when pandas emits
    # them as bool) to 0/1 integers.
    bool_cols = df_encoded.select_dtypes(bool).columns
    if not bool_cols.empty:
        df_encoded[bool_cols] = df_encoded[bool_cols].astype(int)

    if only_quality_cut:
        # Binary target: "Good" -> 1, "Bad" -> 0 (other values left as-is).
        df_encoded["QUALITY_CUT"] = df_encoded["QUALITY_CUT"].replace({"Good": 1, "Bad": 0})
        df_encoded = df_encoded.drop("DEFECT_TYPE", axis=1)
    elif "QUALITY_CUT" in df_encoded.columns:
        df_encoded = df_encoded.drop("QUALITY_CUT", axis=1)
    else:
        print("WARNING: QUALITY_CUT column not found in df_encoded, skipping drop.")

    if drop_duplicates:
        # Deduplicate on features only; target columns do not distinguish rows.
        feature_cols = [col for col in df_encoded.columns if col not in ("DEFECT_TYPE", "QUALITY_CUT")]
        df_encoded = df_encoded.drop_duplicates(subset=feature_cols, keep="first").reset_index(drop=True)

    print("example: \n")
    print(f"total samples in df_encoded: {len(df_encoded)}")
    if "DEFECT_TYPE" in df_encoded.columns:
        no_defect = df_encoded["DEFECT_TYPE"] == "No Defects"
        print(f'total samples good in df_encoded: {int(no_defect.sum())}')
        print(f'total samples bad in df_encoded: {int((~no_defect).sum())}')

    # Drop these categorical columns unless the caller asked to encode them.
    for leftover in ("TECHNOLOGY_GAS", "CONTOUR_LASER_MODE", "LASER_TYPE"):
        if leftover in requested:
            continue
        if leftover in df_encoded.columns:
            df_encoded = df_encoded.drop(columns=leftover)
        else:
            print(f"WARNING: {leftover} column not found in df_encoded, skipping drop.")

    return df_encoded