import math
import numpy as np
import bezier
import warnings
import folium
import datetime
from router_function import router
from scipy import interpolate


class Path:
    """A route between two locations, with helpers to densify, smooth,
    plot and query it.

    The route is stored in ``self.routeLatLons`` as a list of 2-element
    points. ``plot`` hands them to folium as map locations, so they are
    presumably ``(lat, lon)`` pairs (NOTE(review): confirm against
    ``router``). Every transforming method replaces ``self.routeLatLons``
    and also returns the new list.
    """

    def __init__(self, source_geo, dest_geo, filename):
        """Compute the route from *source_geo* to *dest_geo*.

        All three arguments are forwarded unchanged to ``router``.
        NOTE(review): *filename* is presumably the map/graph data file that
        ``router`` reads — confirm.
        """
        self.source_geo = source_geo
        self.dest_geo = dest_geo
        self.file = filename
        self.routeLatLons = router(self.source_geo, self.dest_geo, self.file)

    def increase_density(self):
        """Resample the route so that all segments have roughly equal length.

        The shortest non-zero segment sets the step length. A segment ``k``
        times longer than the step is split into ``floor(k)`` equal parts.
        Original vertices are kept, each appears once, and the route still
        ends at its last point. Zero-length segments (repeated points) are
        collapsed.

        Returns:
            The new ``self.routeLatLons`` list of ``(x, y)`` tuples. A route
            with fewer than two points, or with no non-zero segment, is
            returned unchanged.
        """
        points = self.routeLatLons
        if len(points) < 2:
            return self.routeLatLons

        distances = [math.dist(p0, p1) for p0, p1 in zip(points, points[1:])]
        # Ignore zero-length segments when picking the step: a step of 0
        # would divide by zero below.
        non_zero = [d for d in distances if d > 0]
        if not non_zero:
            return self.routeLatLons
        step = min(non_zero)

        smoothed_x, smoothed_y = [], []
        for (x0, y0), (x1, y1), dist in zip(points, points[1:], distances):
            # dist >= step for every non-zero segment, so n >= 1 there.
            # Zero-length segments give n == 0 and contribute nothing.
            n = math.floor(dist / step)
            # endpoint=False: a segment's end is the next segment's start.
            # Emitting it here would duplicate every interior vertex.
            smoothed_x.extend(np.linspace(x0, x1, n, endpoint=False))
            smoothed_y.extend(np.linspace(y0, y1, n, endpoint=False))

        # Only the final vertex was excluded by endpoint=False; add it back.
        smoothed_x.append(points[-1][0])
        smoothed_y.append(points[-1][1])

        self.routeLatLons = list(zip(smoothed_x, smoothed_y))
        return self.routeLatLons

    def increase_density_fixed(self, density=5):
        """Split every segment of the route into *density* equal parts.

        Each original vertex is kept and followed by ``density - 1`` evenly
        spaced points towards the next vertex. The route therefore still
        starts and ends at its original endpoints.

        Args:
            density: number of sub-segments per original segment; must be
                at least 1 (1 leaves the route unchanged).

        Returns:
            The new ``self.routeLatLons`` list of points. An empty route is
            returned unchanged.

        Raises:
            ValueError: if *density* is smaller than 1.
        """
        if density < 1:
            raise ValueError(f"density must be >= 1, got {density}")

        points = self.routeLatLons
        if not points:
            return self.routeLatLons

        smoothed_x, smoothed_y = [], []
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            # k starts at 0 so the segment's own start vertex is kept.
            # k stops before `density`, because that point is the next
            # segment's start.
            smoothed_x.extend(x0 + k * (x1 - x0) / density for k in range(density))
            smoothed_y.extend(y0 + k * (y1 - y0) / density for k in range(density))

        smoothed_x.append(points[-1][0])
        smoothed_y.append(points[-1][1])

        self.routeLatLons = list(zip(smoothed_x, smoothed_y))
        return self.routeLatLons

    def _turn_angle_at(self, ind):
        """Return the signed heading change (radians) at route vertex *ind*.

        The result is the incoming heading minus the outgoing heading,
        wrapped into (-pi, pi]. The first and last vertices have no turn
        and return 0.0.
        """
        points = self.routeLatLons
        if ind <= 0 or ind >= len(points) - 1:
            return 0.0

        (xp, yp), (x, y), (xn, yn) = points[ind - 1], points[ind], points[ind + 1]
        prev_angle = math.atan2(y - yp, x - xp)
        next_angle = math.atan2(yn - y, xn - x)
        diff = prev_angle - next_angle
        # A raw difference of two atan2 values lies in (-2pi, 2pi). Wrap it
        # so that, e.g., 179deg -> -179deg reads as a 2deg turn, not 358deg.
        return math.atan2(math.sin(diff), math.cos(diff))

    def dir_traj(self, pt, conv=False):
        """Return the signed heading change of the route at point *pt*.

        Args:
            pt: a point expected to be in ``self.routeLatLons``. If it occurs
                several times, its first occurrence is used.
            conv: if True, return degrees instead of radians.

        Returns:
            Incoming heading minus outgoing heading, wrapped into (-pi, pi]
            (or (-180, 180] degrees). Returns 0 when *pt* is the first or
            last point of the route, or is not on the route.
        """
        try:
            ind = self.routeLatLons.index(pt)
        except ValueError:
            # pt is not on the route: no yaw correction can be applied.
            return 0

        angle = self._turn_angle_at(ind)
        return math.degrees(angle) if conv else angle

    def smooth(self, alpha=0.2, accuracy=7, threshold=math.pi / 6):
        """Round off the route's sharp corners with quadratic Bezier arcs.

        Every interior vertex whose absolute turn angle exceeds *threshold*
        is replaced by *accuracy* points sampled on a quadratic Bezier curve
        with three control points:

        - a point *alpha* of the way back along the incoming segment;
        - the vertex itself;
        - a point *alpha* of the way along the outgoing segment.

        All other vertices, including both endpoints, are kept unchanged.
        Consecutive duplicate points are then collapsed.

        Args:
            alpha: fraction of each adjacent segment used by the arc. It is
                clamped to [0, 0.5] with a warning, so arcs on neighbouring
                corners cannot overlap.
            accuracy: number of samples taken on each arc.
            threshold: minimum absolute turn angle (radians) that triggers
                smoothing.

        Returns:
            The new ``self.routeLatLons`` list of points.
        """
        if alpha < 0:
            warnings.warn("WARNING: the selected value for alpha could not be accepted. It has been set to 0 by default")
            alpha = 0
        elif alpha > 0.5:
            warnings.warn("WARNING: the selected value for alpha could not be accepted. It has been set to 0.5 by default")
            alpha = 0.5

        points = self.routeLatLons
        last = len(points) - 1
        smoothed_path = []

        for i, point in enumerate(points):
            # Turn angles are computed by index, not by value lookup. This
            # avoids an O(n) search per vertex and stays correct when the
            # route revisits a point.
            if 0 < i < last and abs(self._turn_angle_at(i)) > threshold:
                x0, y0 = points[i - 1]
                x1, y1 = point
                x2, y2 = points[i + 1]

                start = (x1 - alpha * (x1 - x0), y1 - alpha * (y1 - y0))
                end = (x1 + alpha * (x2 - x1), y1 + alpha * (y2 - y1))

                nodes = np.array([start, (x1, y1), end]).T
                curve = bezier.Curve.from_nodes(nodes)
                samples = curve.evaluate_multi(np.linspace(0, 1, accuracy))
                smoothed_path.extend(zip(samples[0], samples[1]))
            else:
                smoothed_path.append(point)

        # Collapse only *consecutive* repeats, e.g. all-equal samples when
        # alpha == 0, or touching arcs when alpha == 0.5. A route that
        # legitimately passes the same point twice keeps both visits.
        deduped = []
        for p in smoothed_path:
            if not deduped or p != deduped[-1]:
                deduped.append(p)

        self.routeLatLons = deduped
        return self.routeLatLons

    def plot(self, name, zoom_start=20, save=True, plot_circles=False):
        """Draw the route on an interactive folium map.

        Args:
            name: prefix of the output file name.
            zoom_start: initial zoom level; the map is centred on the first
                route point.
            save: if True, write the map to
                ``<name>_<YYYY-mm-dd_HH-MM-SS>.html``. The time part uses
                dashes because ':' is not allowed in Windows file names.
            plot_circles: if True, also mark every route point with a dot.

        Returns:
            The ``folium.Map`` that was built.

        Raises:
            IndexError: if the route is empty (there is no point to centre on).
        """
        my_map = folium.Map(location=self.routeLatLons[0], zoom_start=zoom_start)

        if plot_circles:
            for coord in self.routeLatLons:
                folium.CircleMarker(
                    location=[coord[0], coord[1]],
                    radius=2,
                    color="blue",
                    stroke=False,
                    fill=True,
                    fill_opacity=1.0,  # Leaflet opacities range over 0..1
                ).add_to(my_map)

        folium.PolyLine(locations=self.routeLatLons, color='blue', weight=1).add_to(my_map)

        if save:
            timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            my_map.save(f"{name}_{timestamp}.html")

        return my_map

    def closest_point(self, center, buffer_radius):
        """Return the route point nearest to *center* within *buffer_radius*.

        Distance is the plain Euclidean distance on the stored coordinates.
        NOTE(review): if these are lat/lon, *buffer_radius* is in degrees,
        not metres — confirm what callers pass. On ties, the earliest point
        on the route wins.

        Returns:
            The closest point, or None if no point lies within the radius.
        """
        closest_point = None
        min_distance = float('inf')

        for point in self.routeLatLons:
            distance = math.dist(point, center)
            if distance <= buffer_radius and distance < min_distance:
                min_distance = distance
                closest_point = point

        return closest_point
    
    