import os
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
import calibration.WT_calibration as cal
from mpl_toolkits.mplot3d import Axes3D

def read_csv_files(directory):
    """
    Load every ``data_export.csv`` found under ``directory`` into its own DataFrame.

    The directory tree is walked recursively with ``os.walk``. Every file whose
    path ends with ``data_export.csv`` is read with ``pandas.read_csv`` and gets
    an extra ``id`` column. The id is the integer after the last ``_`` in the
    name of the folder that contains the file (e.g. ``sensor_12`` -> ``12``).

    Args:
        directory: Path of the root directory to search.

    Returns:
        list[pandas.DataFrame]: One DataFrame per matching file. The frames are
        NOT concatenated; ``select_data`` in this module does that.

    Raises:
        ValueError: If a folder holding a matching file does not end in
            ``_<integer>``.
    """
    data_frames = []
    for subdir, dirs, files in os.walk(directory):
        # os.walk order depends on the file system; sorting makes the result reproducible.
        dirs.sort()
        for file in sorted(files):
            filepath = os.path.join(subdir, file)
            if not filepath.endswith("data_export.csv"):
                continue
            # abspath removes trailing separators and resolves "." so the real
            # folder name is found on every platform (not only with "/" separators).
            folder_name = os.path.basename(os.path.abspath(subdir))
            suffix = folder_name.split("_")[-1]
            try:
                folder_id = int(suffix)
            except ValueError as err:
                raise ValueError(
                    f"Cannot derive a numeric id from folder {folder_name!r} "
                    f"(file {filepath!r})"
                ) from err
            df = pd.read_csv(filepath)
            df['id'] = folder_id
            data_frames.append(df)
    return data_frames


def select_data(data_frames, ids=None, start_date=None, end_date=None, labels=None, features=None):
    """
    Filter a list of DataFrames by id, date range and columns, then concatenate them.

    Args:
        data_frames: List of DataFrames. Each needs an ``id`` column; a ``Date``
            column is needed for date filtering and calendar features, and a
            ``Watermark`` column when ``"SWP"`` is requested.
        ids: Optional list of ids. A frame is kept only if the ``id`` of its
            first row is in this list.
        start_date: Optional lower bound (inclusive) for ``Date``, e.g. a string
            in the format ``'%Y-%m-%d %H:%M:%S'``.
        end_date: Optional upper bound (inclusive) for ``Date``, same format.
        labels: Optional list of column names to keep as labels.
        features: Optional list of column names to keep as features. The names
            ``"day"``, ``"month"`` and ``"hour"`` are derived from ``Date``;
            ``"SWP"`` is derived by applying ``cal.double_calibration`` to
            ``Watermark``.

    Returns:
        pandas.DataFrame: The selected rows of all kept frames, concatenated
        with a fresh index. When ``labels`` or ``features`` is given, only
        ``Date``, ``id`` and the requested columns are kept. An empty DataFrame
        is returned when nothing is selected.

    Notes:
        The input frames are never modified; all work happens on copies.
        ``Date`` is parsed to datetime only when a date bound or a calendar
        feature is requested.
    """
    labels = list(labels) if labels else []
    features = list(features) if features else []
    requested = labels + features
    calendar_parts = [part for part in ("day", "month", "hour") if part in requested]
    needs_dates = bool(start_date or end_date or calendar_parts)

    selected = []
    for df in data_frames:
        # An empty frame has no first row to read the id from, so it cannot match.
        if ids and (df.empty or df['id'].iloc[0] not in ids):
            continue

        df = df.copy()  # keep the caller's frame untouched
        if "SWP" in features:
            df['SWP'] = df['Watermark'].apply(cal.double_calibration)
        if needs_dates:
            df['Date'] = pd.to_datetime(df['Date'], format='%Y-%m-%d %H:%M:%S')
        # Derive calendar columns before row filtering so the assignment is made
        # on the private copy rather than on a filtered slice.
        for part in calendar_parts:
            df[part] = getattr(df['Date'].dt, part)

        if start_date:
            df = df[df['Date'] >= start_date]
        if end_date:
            df = df[df['Date'] <= end_date]

        if requested:
            wanted = {"Date", "id", *requested}
            # Requested names that are absent from the frame are ignored.
            df = df.loc[:, [col for col in df.columns if col in wanted]]
        selected.append(df)

    if not selected:
        return pd.DataFrame()
    # One concat at the end instead of re-copying the accumulated frame each iteration.
    return pd.concat(selected, ignore_index=True)


def normalize_data(data, normalization_ids, columns_to_normalize, reference_ids, forced_norm_values=None):
    """
    Min-max normalize the given columns of ``data`` in place.

    Every row of ``data`` is normalized as ``(value - min) / (max - min)``.
    The min and max of each column come from the rows whose ``id`` is in
    ``reference_ids``, unless overridden through ``forced_norm_values``.

    Args:
        data: DataFrame to normalize. It must contain an ``id`` column and is
            modified in place.
        normalization_ids: Currently unused; all rows are normalized and no row
            is dropped. Kept so existing callers keep working.
        columns_to_normalize: List of column names to normalize.
        reference_ids: Ids of the rows used to compute the min and max.
        forced_norm_values: Optional dict mapping a column name to a dict with
            keys ``'min'`` and ``'max'``. Columns missing from it fall back to
            the values computed from the reference rows.

    Returns:
        tuple: ``(data, data_min, data_max)`` where ``data`` is the same
        (modified) DataFrame and the other two are Series indexed by column
        name holding the bounds that were used.

    Notes:
        A column whose max equals its min produces NaN/inf values, because the
        division is not guarded.
    """
    data_to_normalize = data.loc[:, columns_to_normalize]
    data_to_reference = data.loc[data['id'].isin(reference_ids), columns_to_normalize]

    if forced_norm_values is None:
        data_min = data_to_reference.min()
        data_max = data_to_reference.max()
    else:
        # Explicit float dtype: a dtype-less empty Series is object-typed on
        # recent pandas, which would turn the normalized columns into objects.
        data_min = pd.Series(index=columns_to_normalize, dtype=float)
        data_max = pd.Series(index=columns_to_normalize, dtype=float)
        for column in columns_to_normalize:
            if column in forced_norm_values:
                data_min[column] = forced_norm_values[column]['min']
                data_max[column] = forced_norm_values[column]['max']
            else:
                data_min[column] = data_to_reference[column].min()
                data_max[column] = data_to_reference[column].max()

    normalized = (data_to_normalize - data_min) / (data_max - data_min)
    # Replace the columns instead of writing through .loc so integer columns
    # can become float without an incompatible-dtype error.
    data[columns_to_normalize] = normalized
    return data, data_min, data_max


def clip_data(df, clip_ranges):
    """
    Return a copy of ``df`` whose selected columns are limited to given bounds.

    Args:
        df (pandas.DataFrame): Source dataframe; it is left unchanged.
        clip_ranges (dict): Maps a column name to a two-element sequence
                            ``[minimum, maximum]`` used as the clipping bounds.

    Returns:
        pandas.DataFrame: A copy of ``df`` with the listed columns clipped.
    """
    result = df.copy()
    for column, bounds in clip_ranges.items():
        lower, upper = bounds
        clipped = result[column].clip(lower, upper)
        result[column] = clipped
    return result


def apply_math_functions(dataframe, functions_to_apply,features_list):
    """
    Add transformed copies of columns to ``dataframe``.

    Args:
        dataframe: DataFrame to extend; new columns are added in place.
        functions_to_apply: Dict mapping a source column name to a list of
            transform names. Supported names are ``"square"`` and ``"cube"``;
            any other name is skipped.
        features_list: List that receives the name of every column created.
            It is modified in place.

    Returns:
        The same ``dataframe`` object, with one new column per applied
        transform named ``<column>_<transform>``.
    """
    transforms = {
        "square": lambda values: values**2,
        "cube": lambda values: values**3,
    }

    for source_column, requested in functions_to_apply.items():
        for transform_name in requested:
            if transform_name not in transforms:
                continue
            new_column = source_column + "_" + transform_name
            transform = transforms[transform_name]
            dataframe[new_column] = transform(dataframe[source_column])
            features_list.append(new_column)
    return dataframe


def add_delayed_columns(df,shift_columns,features):
    """
    Add lagged copies of selected columns and drop the rows left incomplete.

    For every column ``col`` in ``shift_columns`` that exists in ``df``, the
    columns ``col_delayed_1`` ... ``col_delayed_n`` are added, where column
    ``i`` is ``col`` shifted down by ``i`` rows.

    Args:
        df: DataFrame to extend. The lagged columns are added to it in place.
        shift_columns: Dict mapping a column name to its number of lags ``n``,
            or ``None`` to add nothing. Names missing from ``df`` are ignored.
        features: List that receives the name of every column created. It is
            modified in place.

    Returns:
        pandas.DataFrame: ``df`` unchanged when ``shift_columns`` is ``None``;
        otherwise a new DataFrame with every row containing a missing value
        removed (this includes the leading rows emptied by the shifts).
    """
    if shift_columns is None:
        # Nothing to add: return the frame itself rather than falling through to None.
        return df

    for col, n_lags in shift_columns.items():
        if col not in df.columns:
            continue
        for lag in range(1, n_lags + 1):
            delayed_col_name = col + '_delayed_' + str(lag)
            df[delayed_col_name] = df[col].shift(lag)
            features.append(delayed_col_name)

    return df.dropna()


def plot_dataframe(df, columns, title=None, save_path=None):
    """
    Plot each requested column against ``df['Date']`` in its own stacked subplot.

    One subplot is created per entry of ``columns``; all subplots share the
    x-axis. Each curve gets a colour from a fixed palette, which is cycled
    when there are more columns than colours. Every subplot is titled with its
    column name, has a grid, and has its x-axis ticks removed.

    Args:
        df: DataFrame containing a ``Date`` column and the columns to plot.
        columns: List of column names to plot. Must not be empty.
        title: Optional title for the whole figure.
        save_path: Optional file path. When given, the figure is saved there and
            closed; otherwise it is shown with ``plt.show()``.

    Raises:
        ValueError: If ``columns`` is empty.
    """
    num_plots = len(columns)
    if num_plots == 0:
        raise ValueError("columns must contain at least one column name")

    # squeeze=False always returns a 2-D array of axes, so a single column
    # works too (the default returns a bare Axes that cannot be indexed).
    fig, axs = plt.subplots(num_plots, 1, figsize=(10, 5*num_plots), sharex=True, squeeze=False)
    axes = axs[:, 0]

    # Palette used for the curves; reused from the start when exhausted.
    colors = ['red', 'green', 'blue', 'orange', 'purple', 'brown', 'pink', 'gray']

    for i, col in enumerate(columns):
        ax = axes[i]
        ax.plot(df['Date'], df[col], color=colors[i % len(colors)])
        ax.set_title(col)
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax.set_xticks([])
        ax.grid()

    if title:
        fig.suptitle(title)
    if save_path:
        fig.savefig(save_path)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

def plot_points(df, x_col, color_col, y_col=None, z_col=None, title=None, save_path=None):
    """
    Scatter-plot the rows of ``df`` in 1, 2 or 3 dimensions, coloured red to green.

    The x coordinate comes from ``x_col``. If ``y_col`` is given it supplies the
    y coordinate; otherwise y is 0 for every point. If ``z_col`` is given the
    plot is drawn on 3D axes with z taken from that column.

    Each point's colour is the RGB triple ``(1 - v, v, 0)`` where ``v`` is its
    value in ``color_col``: 0 is pure red and 1 is pure green.

    Args:
        df: DataFrame holding the coordinate and colour columns.
        x_col: Name of the column used for x.
        color_col: Name of the column with colour values, expected in [0, 1].
        y_col: Optional name of the column used for y.
        z_col: Optional name of the column used for z (switches to a 3D plot).
        title: Optional axes title; defaults to ``'Points colored by <color_col>'``.
        save_path: Optional file path. When given, the figure is saved there and
            closed; otherwise it is shown with ``plt.show()``.
    """
    x = df[x_col]
    c = np.asarray(df[color_col], dtype=float)
    # Explicit per-point RGB. No cmap is passed: matplotlib ignores a colormap
    # when the colours are already given as RGB rows.
    rgb = np.array([1 - c, c, np.zeros_like(c)]).T
    y = df[y_col] if y_col is not None else np.zeros_like(x)

    fig = plt.figure()
    if z_col is not None:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(x, y, df[z_col], c=rgb)
        ax.set_zlabel(z_col)
    else:
        ax = fig.add_subplot(111)
        ax.scatter(x, y, c=rgb)

    ax.set_xlabel(x_col)
    if y_col is not None:
        ax.set_ylabel(y_col)
    elif z_col is None:
        # One-dimensional plot: the constant y axis carries no information.
        ax.get_yaxis().set_visible(False)

    ax.set_title(title if title else 'Points colored by {}'.format(color_col))

    if save_path:
        fig.savefig(save_path)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

def plot_data(dataframe, columns, name):
    """
    Plot each requested column against the dataframe index and save it as a PNG.

    The images are written to ``../Dataplots/<name>/<column>.png`` (relative to
    the current working directory); the folder is created if needed. Columns
    that are not in the dataframe are reported on stdout and skipped.

    Parameters:
        - dataframe: The input dataframe.
        - columns: A list of columns to plot.
        - name: The name of the folder where the plots will be saved.

    Returns:
        - None
    """
    # Create a folder to save the plots
    output_dir = f'../Dataplots/{name}'
    os.makedirs(output_dir, exist_ok=True)

    for column in columns:
        if column not in dataframe.columns:
            print(f"Column {column} not found in the dataframe.")
            continue

        # Use a dedicated figure so nothing already drawn on the current
        # figure ends up in the saved image.
        fig, ax = plt.subplots()
        ax.plot(dataframe[column])
        ax.set_xlabel('Index')
        ax.set_ylabel(column)
        ax.set_title(f"Plot of {column}")

        # Save the plot, then release the figure
        fig.savefig(f'{output_dir}/{column}.png')
        plt.close(fig)


# Disabled helper, kept for reference. TODO: restore it as a real function or delete it.
# Every line must stay commented out: at this indentation an uncommented body
# becomes part of plot_data, where `save_path` is not defined.
#
#def plot_dataframes(dataframes, columns, legend_entries, save_path):
#    """
#    Plot specified columns from a list of dataframes and save the resulting image in a specified directory.
#
#    Parameters:
#        - dataframes: A list of pandas dataframes to plot.
#        - columns: A string representing the name of a column to plot from each dataframe.
#        - legend_entries: A list of strings to use as legend entries.
#        - save_path: The path to save the resulting image.
#
#    Returns:
#        - None
#    """
#    # Create the directory structure if it does not exist
#    os.makedirs(os.path.dirname(save_path), exist_ok=True)
#    # Turn off display of figures on the screen
#    plt.ioff()
#    # Create a new figure and axes
#    plt.figure()
#    for i, (df,legend) in enumerate(zip(dataframes,legend_entries)):
#        df[columns].plot(label =legend)
#    # Add legend
#    plt.legend()
#    # Save the resulting image
#    plt.savefig(save_path)
#    plt.close() # close

