import math
import os
import pandas as pd
from Configs import Configs
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
from mycolorpy import colorlist as mcp
import matplotlib.cm as cm

class Utility:
    """Helpers for loading data, computing point distances, splitting datasets and plotting."""

    def get_dataframe_from_excel(self, uri):
        """Read every sheet of the Excel workbook at ``uri``.

        Returns:
            The result of ``pd.read_excel(uri, sheet_name=None)``: a dict mapping
            sheet name to DataFrame.
        """
        return pd.read_excel(uri, sheet_name=None)

    @staticmethod
    def calculate_distance(point1, point2):
        """Return the Euclidean distance between two 3-component points.

        Both arguments must unpack to exactly three numbers ``(x, y, z)``;
        otherwise the tuple unpacking raises ``ValueError``.
        """
        (x1, y1, z1) = point1
        (x2, y2, z2) = point2
        return math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2 + (z2 - z1) ** 2)

    def get_files_in_directory(self, directory):
        """Return the paths of all files under ``directory``, recursively.

        Paths are built with ``os.path.join(root, name)`` in ``os.walk`` order.
        A directory that does not exist yields an empty list.
        """
        return [os.path.join(root, name)
                for root, _subdirs, names in os.walk(directory)
                for name in names]

    def get_test_train_from_dataframe(self, data, labels, test_size=0.25, random_state=80):
        """Split ``data``/``labels`` into train and test parts with sklearn's ``train_test_split``.

        Args:
            data: feature table.
            labels: targets aligned with ``data``.
            test_size: fraction of samples held out for testing (default 0.25).
            random_state: seed for a reproducible shuffle (default 80).

        Returns:
            ``(data_train, data_test, labels_train, labels_test)`` as produced by
            ``train_test_split``.
        """
        return train_test_split(data, labels, test_size=test_size, random_state=random_state)

    def get_configs(self, is_pd=False):
        """Build the ``Configs`` for either the PD or the healthy JSON dataset.

        Args:
            is_pd: when truthy, collect files under ``json/PD/``; otherwise under
                ``json/HEALTHY/``. The joint-position tag and output suffix follow
                the same choice.

        Returns:
            A ``Configs`` built from the file list plus the fixed key, column and
            directory names below, passed positionally in the original order.
        """
        if is_pd:
            file_paths = self.get_files_in_directory("json/PD/")
            # The PD and healthy branches use different joint-position tag strings.
            joint_position_tag = 'joint_positions'
            output_extension = '_parkinson'
        else:
            file_paths = self.get_files_in_directory("json/HEALTHY/")
            joint_position_tag = 'joints_position'
            output_extension = '_healthy'

        frames = 'frames'
        hands = 'hands'
        timestamp_usec = 'timestamp_usec'
        timestamp = 'timestamp'
        distance = 'Distance'
        velocity = 'Velocity'
        combined = 'Combined'
        img_output_dir = 'Imgs'
        # Distinct name so the string does not shadow the boolean ``is_pd`` argument.
        is_pd_column = "is_pd"
        return Configs(file_paths, joint_position_tag, frames, hands, timestamp_usec, timestamp, distance, velocity,
                       combined, output_extension, img_output_dir, is_pd_column)

    @staticmethod
    def _save_line_chart(x, y, xlabel, ylabel, path):
        """Plot ``y`` against ``x`` on a new figure, save it to ``path`` and always close the figure."""
        fig = plt.figure()
        try:
            ax = fig.add_subplot(111)
            ax.plot(x, y)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            fig.savefig(path)
        finally:
            # Close even if plotting/saving fails, so figures do not accumulate.
            plt.close(fig)

    def save_graphs(self, df, configs, json_file_path):
        """Save distance-vs-timestamp and velocity-vs-timestamp line charts as PNG files.

        Charts are written to
        ``<configs.img_output_dir>/<json_file_path without '.json'>/<configs.distance | configs.velocity>/``.
        Those directories are created if missing. Titles are not drawn.

        Args:
            df: table containing the ``configs.timestamp``, ``configs.distance``
                and ``configs.velocity`` columns.
            configs: supplies the column names and ``img_output_dir``.
            json_file_path: source JSON path; it determines the output
                sub-directory and the chart file names.
        """
        json_file_name = os.path.basename(json_file_path)

        # Mirror the source path under the image directory: forward slashes,
        # and every '.json' occurrence stripped.
        modified_uri = json_file_path.replace('\\', '/').replace('.json', '')

        distance_dir = os.path.join(configs.img_output_dir, modified_uri, configs.distance).replace('\\', '/')
        velocity_dir = os.path.join(configs.img_output_dir, modified_uri, configs.velocity).replace('\\', '/')
        os.makedirs(distance_dir, exist_ok=True)
        os.makedirs(velocity_dir, exist_ok=True)

        # File names keep the full base name, e.g. 'distance_chart<name>.json.png'.
        self._save_line_chart(
            df[configs.timestamp], df[configs.distance],
            'Timestamp (μs)', 'Distance (mm)',
            os.path.join(distance_dir, 'distance_chart' + json_file_name + '.png'))

        # NOTE(review): the 1000x factor presumably converts mm/μs to m/s
        # (1 mm/μs = 1000 m/s), consistent with the axis labels — confirm against
        # how the Velocity column is computed.
        self._save_line_chart(
            df[configs.timestamp], 1000 * df[configs.velocity],
            'Timestamp (μs)', 'Velocity (m/s)',
            os.path.join(velocity_dir, 'velocity_chart' + json_file_name + '.png'))

    def create_feature_scattered_graph(self, df, column_indices, column_names, color_column_value, legend_values,
                                       is_color_bar_visible, plot_title):
        """Show a 3D scatter of three feature columns, coloured by ``color_column_value``.

        Args:
            df: DataFrame holding the features.
            column_indices: positional column indices. The first three give the
                x, y and z coordinates.
            column_names: axis labels. ``None`` uses the DataFrame's column names.
                A sequence whose length equals ``len(column_indices)`` uses its
                first three entries. Anything else leaves the axes unlabelled.
            color_column_value: one value per row. The values are min/max
                normalised and mapped through the ``bwr`` colormap.
            legend_values: legend labels (see the note on the legend below).
            is_color_bar_visible: when truthy, add a colour bar.
            plot_title: currently not used; no title is drawn.
        """
        selected_data = df.iloc[:, column_indices]
        x = selected_data.iloc[:, 0]
        y = selected_data.iloc[:, 1]
        z = selected_data.iloc[:, 2]

        # Normalise colour values to [0, 1] and map them through the bwr colormap.
        norm = plt.Normalize(color_column_value.min(), color_column_value.max())
        mappable = plt.cm.ScalarMappable(norm=norm, cmap=plt.cm.bwr)
        colors = mappable.to_rgba(color_column_value)

        fig = plt.figure(figsize=(10, 6))
        ax = fig.add_subplot(111, projection='3d')
        # Draw the points once; they do not depend on which column index is current.
        ax.scatter(x, y, z, c=colors)

        if column_names is None:
            axis_labels = [df.columns[i] for i in column_indices[:3]]
        elif len(column_names) == len(column_indices):
            axis_labels = column_names[:3]
        else:
            axis_labels = None

        if axis_labels is not None:
            # Pad labels locally instead of mutating global rcParams (which would
            # only affect figures created afterwards, not this one).
            ax.set_xlabel(axis_labels[0], labelpad=10)
            ax.set_ylabel(axis_labels[1], labelpad=10)
            ax.set_zlabel(axis_labels[2], labelpad=10)

        if is_color_bar_visible:
            cbar = fig.colorbar(mappable, ax=ax, pad=0.1)
            cbar.set_label('Color Column Value')

        # NOTE(review): there is one proxy handle per entry of column_indices, and
        # handle i takes the colour of data row i. It only matches legend_values if
        # the first rows represent each legend class — confirm the intent.
        proxy_artists = [plt.Line2D([0], [0], linestyle='none', c=colors[i], marker='o')
                         for i in range(len(column_indices))]
        ax.legend(handles=proxy_artists, labels=legend_values)
        plt.show()
