import random
import numpy as np
import constants
from OrTools_solver import create_distance_matrix
from Customer import Customer
from Vehicle import Vehicle


class TabuSearch:
    """Tabu search heuristic for a capacitated vehicle routing problem.

    A solution is a list of ``self.V`` routes. Each route is a list of customer
    ids that starts and ends with the depot id ``0``; an unused route is
    ``[0, 0]``. The only constraint checked is the per-route capacity ``self.C``
    (see ``solution_feasible``).
    """

    def __init__(self, customers, base_info):
        """Build the instance from a customer table and depot/fleet information.

        :param customers: rows iterated with ``iterrows()``; each row must expose
            ``index``, ``kg``, ``service_time``, ``x`` and ``y`` (presumably a
            pandas DataFrame -- TODO confirm against callers).
        :param base_info: mapping with ``depot_coord``, ``max_load`` and optionally
            ``best_known_solution``; it is also passed to ``create_distance_matrix``.
        """
        # assumes depot_coord is an (x, y) pair -- y must come from index 1
        depot_coord = base_info['depot_coord']
        self.depot = Customer(id=0, demand=0, service_time=0,
                              coord_x=depot_coord[0], coord_y=depot_coord[1])
        self.V = constants.NUM_VEHICLES  # number of routes in every solution
        self.C = base_info['max_load']   # capacity of each route
        self.N = len(customers)          # number of customers
        self.T = constants.MAX_TIME

        # customer id -> Customer object
        self.customers = {customer['index']: Customer(id=customer['index'], demand=customer['kg'],
                                                      service_time=customer['service_time'],
                                                      coord_x=customer['x'], coord_y=customer['y'])
                          for _, customer in customers.iterrows()}
        self.vehicle = [Vehicle(i + 1, self.C) for i in range(self.V)]
        # map_matrix translates a customer id into a row/column index of distance_matrix
        self.distance_matrix, self.map_matrix = create_distance_matrix(customers, base_info)
        self.TABU = {}              # move tuple -> remaining tabu tenure
        self.best_solution = None   # best solution found so far
        self.best_cost = None       # cost of best_solution
        self.best_candidate = None  # current solution of the search trajectory
        self.candidate_cost = None  # cost of best_candidate
        self.best_costs = []        # best_cost after each iteration
        self.costs = []             # candidate_cost after each iteration
        self.processed_solution = {}
        self.best_known_solution = None
        if "best_known_solution" in base_info:
            self.best_known_solution = base_info['best_known_solution']

    def generate_random_solution(self, seed: int = None):
        """Build an initial solution by filling routes greedily from a shuffled customer list.

        Customers are popped from a random permutation and appended to the current
        route while the route's load stays within ``self.C``; once the remaining
        total demand fits in one vehicle, all remaining customers go to that route.

        NOTE(review): if ``self.V`` routes are not enough (or a single demand exceeds
        ``self.C``), the leftover customers are silently left unassigned.

        :param seed: seed for the global ``random`` module (reproducibility).
        :returns: list of ``self.V`` routes, each ``[0, ..., 0]``.
        """
        routes = [[0] for _ in range(self.V)]
        # seed the global RNG for reproducibility
        random.seed(seed)
        samples = random.sample(list(self.customers.values()), k=len(self.customers))
        demands = [customer.demand for customer in samples]
        tot_to_load = sum(demands)

        for route in routes:
            if tot_to_load > self.C:
                to_load = 0
                # cannot exhaust `demands`: the remaining total exceeds the capacity
                while to_load + demands[-1] <= self.C:
                    to_load += demands.pop()
                    route.append(samples.pop().id)
                tot_to_load -= to_load
            else:
                route.extend([samples.pop().id for _ in range(len(samples))])
            route.append(0)

        return routes

    def solution_feasible(self, routes: list) -> bool:
        """Return True when no route's total customer demand exceeds the capacity ``self.C``."""
        for route in routes:
            # route[1:-1] skips the leading and trailing depot
            route_demand = sum(self.customers[i].demand for i in route[1:-1])
            if route_demand > self.C:
                return False
        return True

    def solution_cost(self, solution):
        """Return the total travelled distance over all routes of *solution*."""
        dist, index = self.distance_matrix, self.map_matrix
        routes_cost = [sum(dist[index[a]][index[b]] for a, b in zip(route, route[1:]))
                       for route in solution]
        return sum(routes_cost)

    def initialize_solution(self, solution):
        """Reset the search state so that *solution* is both the current and best solution."""
        self.best_solution = solution
        self.best_cost = self.solution_cost(solution)
        self.best_candidate = solution
        self.candidate_cost = self.best_cost
        self.costs.append(self.best_cost)
        self.best_costs.append(self.best_cost)
        self.processed_solution = {}
        # tabu moves refer to positions in a previous trajectory; they are meaningless now
        self.TABU = {}

    def find_neighborhood(self):
        """Sample one random move of each type for every route pair ``id1 <= id2``.

        Move types (only capacity-feasible results are kept):

        * swap: exchange a random customer of route ``id1`` with one of route
          ``id2``; encoded ``(id1, pos1, id2, pos2)``;
        * open route: move a random customer of ``id1`` into ``id2`` when ``id2`` is
          the first empty route; encoded ``(id1, pos1, id2, -1)``;
        * close route: move the single customer of route ``id2`` into route ``id1``
          at a random position; encoded ``(id1, pos1, id2, -2)``.

        Moves that would reproduce the current solution are not generated. A route
        emptied by a move is moved to the end of the solution.

        :returns: ``(neighborhood, moves)``, parallel lists of candidate solutions
            and the move tuple that produced each one.
        """
        solution = [path[:] for path in self.best_candidate]

        neighborhood = []
        moves = []
        for id1 in range(self.V):
            for id2 in range(id1, self.V):
                # swap between two routes (or two positions of the same route)
                if len(solution[id1]) > 2 and len(solution[id2]) > 2:
                    sampl1 = np.random.randint(1, len(solution[id1]) - 1)
                    sampl2 = np.random.randint(1, len(solution[id2]) - 1)
                    # swapping a position with itself would leave the solution unchanged
                    if id1 != id2 or sampl1 != sampl2:
                        sol_copy = [path[:] for path in solution]
                        sol_copy[id1][sampl1], sol_copy[id2][sampl2] = sol_copy[id2][sampl2], sol_copy[id1][sampl1]
                        if self.solution_feasible(sol_copy):
                            moves.append((id1, sampl1, id2, sampl2))
                            neighborhood.append(sol_copy)

                # move a customer into the first empty route
                elif len(solution[id1]) > 2 and len(solution[id2 - 1]) > 2:
                    sol_copy = [path[:] for path in solution]
                    sampl1 = np.random.randint(1, len(solution[id1]) - 1)
                    sol_copy[id2].insert(1, sol_copy[id1][sampl1])
                    del sol_copy[id1][sampl1]
                    # if an empty route is left in the middle, move it to the end
                    if len(sol_copy[id1]) == 2:
                        empty_route = sol_copy.pop(id1)
                        sol_copy.append(empty_route)
                    if self.solution_feasible(sol_copy):
                        moves.append((id1, sampl1, id2, -1))
                        neighborhood.append(sol_copy)

                # remove a route that serves a single customer (id1 == id2 would be a no-op)
                if id1 != id2 and len(solution[id1]) > 2 and len(solution[id2]) == 3:
                    sol_copy = [path[:] for path in solution]
                    sampl1 = np.random.randint(1, len(solution[id1]))  # insertion position
                    sol_copy[id1].insert(sampl1, sol_copy[id2][1])
                    del sol_copy[id2][1]
                    # if an empty route is left in the middle, move it to the end
                    if len(sol_copy[id2]) == 2:
                        empty_route = sol_copy.pop(id2)
                        sol_copy.append(empty_route)
                    if self.solution_feasible(sol_copy):
                        moves.append((id1, sampl1, id2, -2))
                        neighborhood.append(sol_copy)
        return neighborhood, moves

    def _select_move(self, neighborhood, moves):
        """Return ``(neighbor, cost, move)`` for the cheapest admissible neighbour, or None.

        A neighbour is admissible when its move is not tabu, or when it is tabu but
        beats the best cost found so far (aspiration criterion).
        """
        chosen = None
        for neighbor, move in zip(neighborhood, moves):
            cost = self.solution_cost(neighbor)
            admissible = move not in self.TABU or cost < self.best_cost
            if admissible and (chosen is None or cost < chosen[1]):
                chosen = (neighbor, cost, move)
        return chosen

    def _decay_tabu(self):
        """Age every tabu entry by one iteration, dropping entries whose tenure reached 0."""
        self.TABU = {move: tenure - 1 for move, tenure in self.TABU.items() if tenure != 0}

    def search(self, n_iters=5000, tabu_size=50, print_solution=False, seed: int = None):
        """Run tabu search from a random initial solution.

        :param n_iters: number of iterations.
        :param tabu_size: tenure assigned to the move applied at each iteration.
        :param print_solution: print progress and the final solution.
        :param seed: seed for the initial solution and for ``np.random``.
        :returns: the best cost found.
        """
        sol = self.generate_random_solution(seed=seed)
        self.initialize_solution(sol)
        if print_solution:
            print('Cost: ' + str(self.best_cost))
            print('Solution:')
            for i, route in enumerate(sol):
                if len(route) > 2:
                    print(i, ' ', route)

        # nothing to optimise when there are no customers to serve
        if self.N == 0:
            n_iters = 0

        # seed numpy's global RNG for reproducibility
        np.random.seed(seed)

        while n_iters > 0:
            neighborhood, moves = self.find_neighborhood()
            chosen = self._select_move(neighborhood, moves)

            # with no admissible neighbour the search stays on the current candidate
            if chosen is not None:
                self.best_candidate, self.candidate_cost, chosen_move = chosen
                if self.candidate_cost < self.best_cost:
                    self.best_solution = self.best_candidate
                    self.best_cost = self.candidate_cost
                self.TABU[chosen_move] = tabu_size

            # update cost history
            self.costs.append(self.candidate_cost)
            self.best_costs.append(self.best_cost)

            self._decay_tabu()

            n_iters -= 1
            if n_iters % 100 == 0 and print_solution:
                print(self.best_cost)

        if print_solution:
            print(self.best_cost)
            print(self.best_solution)
            for i, route in enumerate(self.best_solution):
                if len(route) > 2:
                    print(i, ' ', route)

            print("Best_known_solution " + str(self.best_known_solution))
        return self.best_cost
