import numpy as np
import pandas as pd
from scipy.spatial import distance
import constants
from ortools.constraint_solver import routing_enums_pb2
from ortools.constraint_solver import pywrapcp


def create_distance_matrix(customers_selected: pd.DataFrame, base_info: dict, metric: str = 'euclidean'):
    """Build the depot + customers distance matrix and an id -> matrix-row lookup.

    Row/column 0 of the matrix is the depot (``base_info['depot_coord']``); rows
    1..n follow the row order of ``customers_selected`` (its 'x' and 'y' columns).

    Args:
        customers_selected: customers to include; must have 'index', 'x', 'y' columns.
        base_info: must contain 'depot_coord', the depot coordinates.
        metric: distance metric name forwarded to ``scipy.spatial.distance.cdist``.

    Returns:
        (matrix, position_of): the pairwise distance matrix rounded to 3 decimals,
        and a dict mapping each customer's 'index' value to its matrix row, with
        the depot id 0 mapped to row 0.
    """
    # When iterating over several days, an order's row in the distance matrix no
    # longer matches its 'index' value, so keep an explicit id -> row mapping.
    position_of = {
        order_id: row
        for row, order_id in enumerate(customers_selected['index'], start=1)
    }
    # Depot id. NOTE(review): a customer whose 'index' is 0 would be overwritten here.
    position_of[0] = 0

    # Depot first, then the selected customers, stacked as one coordinate array.
    points = np.vstack((base_info['depot_coord'], customers_selected[['x', 'y']]))
    matrix = np.round(distance.cdist(points, points, metric), 3)
    return matrix, position_of


def ortools_solver(customers_selected: pd.DataFrame, base_info: dict, precision: int = 1000,
                   printSolution: bool = False, time: int = 1):
    """Solve a capacitated vehicle routing problem over the selected customers.

    Args:
        customers_selected: customers to visit; must provide the columns used by
            ``create_distance_matrix`` ('index', 'x', 'y') plus 'kg', used as demand.
        base_info: problem data; 'depot_coord' is required. If 'max_load' is
            missing it is set IN PLACE to ``constants.VEHICLE_CAPACITY``.
        precision: scale factor applied to distances before converting them to
            integers, because the OR-Tools routing solver works with integer costs.
        printSolution: if True, print the routes of the solution found.
        time: search time limit, in seconds.

    Returns:
        (objective, precision): the solver's objective rescaled back to distance
        units, and the precision factor used.

    Raises:
        RuntimeError: if the solver finds no solution within the time limit.
    """
    distance_matrix, _ = create_distance_matrix(customers_selected, base_info)
    # OR-Tools works with integers, so scale by "precision" and round explicitly.
    # An implicit float -> int conversion would truncate values such as
    # 4.243 * 1000 == 4242.999... down to 4242.
    scaled_matrix = distance_matrix * precision
    int_matrix = np.rint(scaled_matrix).astype(np.int64)
    demands = np.hstack(([0], customers_selected['kg']))
    index_depot = 0
    # Kept as an in-place default: callers may read base_info['max_load'] afterwards.
    if "max_load" not in base_info:
        base_info['max_load'] = constants.VEHICLE_CAPACITY

    # Create the routing index manager.
    manager = pywrapcp.RoutingIndexManager(len(int_matrix),
                                           constants.NUM_VEHICLES, index_depot)

    # Create Routing Model.
    routing = pywrapcp.RoutingModel(manager)

    def distance_callback(from_index, to_index):
        """Return the integer (scaled) distance between two routing indices."""
        # Convert from routing variable Index to distance matrix NodeIndex.
        from_node = manager.IndexToNode(from_index)
        to_node = manager.IndexToNode(to_index)
        return int(int_matrix[from_node, to_node])

    transit_callback_index = routing.RegisterTransitCallback(distance_callback)

    # Define cost of each arc.
    routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)

    def demand_callback(from_index):
        """Return the demand of the node behind a routing index."""
        # Convert from routing variable Index to demands NodeIndex.
        from_node = manager.IndexToNode(from_index)
        # OR-Tools needs a Python int. NOTE(review): int() truncates fractional
        # 'kg' values; this assumes demands are whole numbers, so confirm with the data.
        return int(demands[from_node])

    demand_callback_index = routing.RegisterUnaryTransitCallback(demand_callback)

    # Add Capacity constraint.
    routing.AddDimensionWithVehicleCapacity(
        demand_callback_index,
        0,  # null capacity slack
        [base_info['max_load']] * constants.NUM_VEHICLES,  # vehicle maximum capacities
        True,  # start cumul to zero
        'Capacity')

    # Setting first solution heuristic and local search metaheuristic.
    search_parameters = pywrapcp.DefaultRoutingSearchParameters()
    search_parameters.first_solution_strategy = (
        routing_enums_pb2.FirstSolutionStrategy.PATH_CHEAPEST_ARC)
    search_parameters.local_search_metaheuristic = (
        routing_enums_pb2.LocalSearchMetaheuristic.GUIDED_LOCAL_SEARCH)
    search_parameters.time_limit.FromSeconds(time)

    # Solve the problem.
    solution = routing.SolveWithParameters(search_parameters)
    if not solution:
        # Previously this surfaced as an opaque AttributeError on None.
        raise RuntimeError(
            f'OR-Tools found no solution within {time}s for {len(customers_selected)} customers')

    # Print solution on console. The float-scaled matrix is passed so the
    # printed "exact" distance is not affected by integer rounding.
    if printSolution:
        print_solution(manager, routing, solution, demands, scaled_matrix, precision)

    return solution.ObjectiveValue() / precision, precision


def print_solution(manager, routing, solution, demands, distance_matrix, precision):
    """Print each non-empty route of a routing solution, followed by totals.

    Args:
        manager: routing index manager, used to map routing indices to nodes.
        routing: routing model the solution belongs to.
        solution: assignment returned by the solver.
        demands: per-node demand, indexed by node.
        distance_matrix: per-node distances in the same scaled units as the
            solver costs; sums are divided by ``precision`` before printing.
        precision: scale factor used to turn solver costs back into distances.

    Routes with zero total load are not printed, and the printed vehicle
    numbers count only the printed routes.
    """
    print(f'Objective: {solution.ObjectiveValue() / precision}')
    grand_distance = 0
    grand_load = 0
    used_vehicles = 0
    matrix_sum = 0
    for vehicle_id in range(constants.NUM_VEHICLES):
        cursor = routing.Start(vehicle_id)
        pieces = [f'Route for vehicle {used_vehicles}:\n']
        leg_cost = 0
        load = 0

        while not routing.IsEnd(cursor):
            node = manager.IndexToNode(cursor)
            load += demands[node]
            pieces.append(f' {node} Load({load}) -> ')
            successor = solution.Value(routing.NextVar(cursor))
            leg_cost += routing.GetArcCostForVehicle(cursor, successor, vehicle_id)
            matrix_sum += distance_matrix[node][manager.IndexToNode(successor)]
            cursor = successor

        pieces.append(f' {manager.IndexToNode(cursor)} Load({load})\n')
        pieces.append(f'Distance of the route: {leg_cost / precision}km\n')
        pieces.append(f'Load of the route: {load}\n')
        if load != 0:
            print(''.join(pieces))
            used_vehicles += 1
        grand_distance += leg_cost
        grand_load += load

    summary = (
        f'Total distance of all routes: {grand_distance / precision}km',
        f'Total load of all routes: {grand_load}',
        f'Total demands: {sum(demands)}',
        f'Total vehicle: {used_vehicles}',
    )
    for line in summary:
        print(line)

    # Distance total recomputed directly from the distance matrix rather than
    # from the solver's arc costs.
    print(round(matrix_sum / precision, 3))
