from typing import Tuple, Optional
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import constants
from scipy.spatial import distance
import os


def create_distribution(x_min: float = 0, x_max: float = 200,
                        y_min: float = 0, y_max: float = 200,
                        length: float = 5, height: float = 5,
                        num_city: int = 5, dim_pop: Optional[list] = None,
                        print_scatter: bool = False, print_heat_map: bool = False,
                        depot_coord: Optional[list] = None, seed: Optional[int] = None,
                        min_density: int = 8) \
        -> Tuple[pd.DataFrame, dict, np.ndarray]:
    """Generate a random demand distribution over a rectangular grid of cells.

    Points are sampled around ``num_city`` randomly placed centres, binned on
    a regular grid, and normalised so that the cell values sum to 1.

    Args:
        x_min, x_max, y_min, y_max: Bounds of the rectangular area.
        length, height: Requested cell size along x and y. Both are
            recomputed so that an integer number of cells exactly tiles the
            area.
        num_city: Number of point clusters to generate.
        dim_pop: Number of points per cluster. When ``None`` each cluster gets
            a random size between 10**3 and 10**6.
        print_scatter: If true, show a scatter plot of the sampled points.
        print_heat_map: If true, show a heat map of the cell probabilities.
        depot_coord: ``[x, y]`` of the depot. Drawn uniformly inside the area
            when ``None``.
        seed: Seed passed to ``np.random.seed`` (global NumPy random state).
        min_density: Baseline added to every cell as
            ``min_density * length * height`` so that cells far from every
            cluster keep a non-zero probability.

    Returns:
        A tuple ``(df, base_info, grid)``:

        * ``df`` has one row per cell with columns ``id``, ``x``, ``y`` (lower
          left corner), ``probability``, ``x_center`` and ``y_center``. Cells
          are numbered row by row, x varying fastest.
        * ``base_info`` is a dict with the area bounds, cell counts, cell
          size, depot coordinates, seed and ``max_load``.
        * ``grid`` is the probability matrix with shape
          ``(num_cell_y, num_cell_x)``; ``grid[row, col]`` is the cell in
          y-bin ``row`` and x-bin ``col``.

    Raises:
        ValueError: If ``dim_pop`` does not have ``num_city`` entries, or if
            the cell size does not yield at least one cell per axis.
    """
    np.random.seed(seed)

    if length <= 0 or height <= 0:
        raise ValueError('length and height must be positive')
    num_cell_x = int((x_max - x_min) / length)
    num_cell_y = int((y_max - y_min) / height)
    if num_cell_x < 1 or num_cell_y < 1:
        raise ValueError('The area must contain at least one cell along each axis')

    # Truncating to an integer number of cells may change the effective cell
    # size, so recompute it from the final cell counts.
    length = (x_max - x_min) / float(num_cell_x)
    height = (y_max - y_min) / float(num_cell_y)

    # Check the dimension of the population.
    if dim_pop is None:
        dim_pop = (10 ** (np.random.rand(num_city) * 3 + 3)).astype(int)
    elif num_city != len(dim_pop):
        raise ValueError('The length of dim_pop must be equal to num_city or dim_pop must be None')

    if depot_coord is None:
        depot_coord = np.array([np.random.rand(1)[0] * (x_max - x_min) + x_min,
                                np.random.rand(1)[0] * (y_max - y_min) + y_min])

    if print_scatter:
        plt.figure()
        plt.axis([x_min, x_max, y_min, y_max])
        plt.gca().set_aspect('equal', adjustable='box')

    city_points = []
    for population in dim_pop:
        # The spread grows with the logarithm of the population; the centre
        # is kept half a spread away from the borders of the area.
        log_pop2 = np.log(population) * 2
        mu_x = np.random.rand(1) * (x_max - x_min - log_pop2) + x_min + log_pop2 / 2
        mu_y = np.random.rand(1) * (y_max - y_min - log_pop2) + y_min + log_pop2 / 2
        mu = np.array([mu_x[0], mu_y[0]])
        cov = np.array([[log_pop2, 0], [0, log_pop2]])
        sample = np.random.multivariate_normal(mean=mu, cov=cov, size=population)
        city_points.append(sample)
        if print_scatter:
            plt.scatter(sample[:, 0], sample[:, 1], s=1)

    if print_scatter:
        plt.scatter(depot_coord[0], depot_coord[1], s=30, c='r', marker='+')
        plt.show()

    points = np.concatenate(city_points) if city_points else np.empty((0, 2))

    grid_x = np.linspace(x_min, x_max, num_cell_x + 1)
    grid_y = np.linspace(y_min, y_max, num_cell_y + 1)

    # histogram2d puts x along the first axis and y along the second.
    # Transpose so rows are y-bins and columns are x-bins, which is the layout
    # pcolormesh and the row-by-row cell numbering below both expect.
    counts, _, _ = np.histogram2d(points[:, 0], points[:, 1], bins=[grid_x, grid_y])
    grid = counts.T

    # Add a minimum baseline to every cell so that orders can also appear
    # outside the cities, then normalise to a probability distribution.
    grid = grid + min_density * length * height
    grid = grid / grid.sum()

    if print_heat_map:
        plt.figure()
        plt.axis([x_min, x_max, y_min, y_max])
        plt.gca().set_aspect('equal', adjustable='box')
        # Row 0 of the matrix is drawn at the bottom, so element [0, 0]
        # appears in the lower-left corner of the plot.
        plt.pcolormesh(grid_x, grid_y, grid, cmap='jet')
        plt.scatter(depot_coord[0], depot_coord[1], s=200, c='r', marker='+')
        plt.colorbar()
        plt.show()

    base_info = {'x_min': x_min, 'x_max': x_max,
                 'y_min': y_min, 'y_max': y_max,
                 'num_cell_x': num_cell_x, 'num_cell_y': num_cell_y,
                 'length': length, 'height': height,
                 'depot_coord': depot_coord, 'seed': seed,
                 'max_load': constants.VEHICLE_CAPACITY}

    # Lower-left corner of every cell, numbered row by row (x varies fastest,
    # cell ids start from 0). The y offset uses the cell height, not length.
    corner_x = np.tile(x_min + np.arange(num_cell_x) * length, num_cell_y).astype(float)
    corner_y = np.repeat(y_min + np.arange(num_cell_y) * height, num_cell_x).astype(float)
    df = pd.DataFrame({
        'id': np.arange(num_cell_x * num_cell_y),
        'x': corner_x,
        'y': corner_y,
        'probability': grid.ravel(),
        'x_center': corner_x + length / 2,
        'y_center': corner_y + height / 2,
    })

    return df, base_info, grid


def print_heat_map_and_orders(base_info: dict, distribution, customers_simulated, day: int = None):
    """Show the demand heat map with the simulated orders and the depot on top.

    Args:
        base_info: Dict providing ``x_min``, ``x_max``, ``y_min``, ``y_max``,
            ``num_cell_x``, ``num_cell_y`` and ``depot_coord``.
        distribution: Matrix of cell values handed to ``plt.pcolormesh``.
        customers_simulated: Object indexable by ``'x'`` and ``'y'`` giving the
            order coordinates (presumably a DataFrame; verify against caller).
        day: Optional day number; when given it is shown in the plot title.
    """
    x_lo, x_hi = base_info['x_min'], base_info['x_max']
    y_lo, y_hi = base_info['y_min'], base_info['y_max']
    x_edges = np.linspace(x_lo, x_hi, base_info['num_cell_x'] + 1)
    y_edges = np.linspace(y_lo, y_hi, base_info['num_cell_y'] + 1)

    plt.figure()
    plt.axis([x_lo, x_hi, y_lo, y_hi])
    plt.gca().set_aspect('equal', adjustable='box')

    # Distribution heat map.
    plt.pcolormesh(x_edges, y_edges, distribution, cmap='jet')
    # Orders.
    plt.scatter(customers_simulated['x'], customers_simulated['y'], s=30, c='r', marker='o')
    # Depot.
    depot = base_info['depot_coord']
    plt.scatter(depot[0], depot[1], s=200, c='g', marker='+')
    plt.colorbar()

    # Only add a title when the day is known.
    if day is not None:
        plt.title('Day ' + str(day))
    plt.show()


def compatible_cells(df_distribution: pd.DataFrame, base_info, rho: float = 0.4):
    """Compute a pairwise compatibility (saving) index between grid cells.

    For two different cells ``i`` and ``j`` the index is
    ``(d(0, i) + d(0, j) - d(i, j)) / (2 * (d(0, i) + d(0, j)))``, where ``0``
    is the depot and ``d`` is the Euclidean distance between cell centres.
    A cell is always compatible with itself (index 0.5).

    Args:
        df_distribution: DataFrame with ``x_center`` and ``y_center`` columns,
            one row per cell.
        base_info: Mapping providing ``depot_coord``.
        rho: Threshold; cell ``j`` is compatible with ``i`` when the index is
            strictly greater than ``rho``.

    Returns:
        ``(index_compatibility, list_compatible_cells)``: the square matrix of
        indices and, for each cell, the array of positions of its compatible
        cells.
    """
    num_cell = len(df_distribution)

    # Row/column 0 is the depot; rows/columns 1..num_cell are the cell centres.
    points = np.vstack((base_info['depot_coord'], df_distribution[['x_center', 'y_center']]))
    dist = distance.cdist(points, points, 'euclidean')

    depot_to_cell = dist[0, 1:]
    cell_to_cell = dist[1:, 1:]

    # pair_sum[i, j] = d(0, i) + d(0, j)
    pair_sum = depot_to_cell[:, np.newaxis] + depot_to_cell[np.newaxis, :]
    same_cell = np.eye(num_cell, dtype=bool)

    # The diagonal is overwritten below; give it a harmless denominator so the
    # division is only meaningful for pairs of distinct cells.
    denominator = 2 * pair_sum
    denominator[same_cell] = 1.0
    saving = (pair_sum - cell_to_cell) / denominator

    # Customers in the same cell are always compatible.
    index_compatibility = np.where(same_cell, 0.5, saving)

    # Keep, for every cell, the cells whose index exceeds the threshold rho.
    list_compatible_cells = [np.where(row > rho)[0] for row in index_compatibility]

    return index_compatibility, list_compatible_cells


def load(path: str, instance_name: str) -> Tuple[pd.DataFrame, dict]:
    """Load a VRP benchmark instance from a text file.

    The file name is ``path + instance_name`` (plain concatenation, so
    ``path`` must end with a separator). The instance set is recognised from
    ``path``: it must contain ``"Vrp-Set-A"`` or ``"Vrp-Set-Li"``.

    Header lines are read by position: line 1 holds the best known solution
    (last token, with its trailing character dropped), line 3 the number of
    nodes and line 5 the vehicle capacity (third token of each line).

    Args:
        path: Directory prefix of the instance file.
        instance_name: File name of the instance.

    Returns:
        ``(orders, base_info)``. ``orders`` has one row per node with columns
        ``x``, ``y``, ``kg``, ``service_time``, ``last_day``,
        ``yet_postponed``, ``cell_id`` and ``index``; the columns not present
        in the file are filled with ``-1`` / ``False``. ``base_info`` holds the
        area bounds, grid size, depot coordinates (the first node), seed, best
        known solution and ``max_load``.

    Raises:
        FileNotFoundError: If the instance file does not exist.
        NotImplementedError: If ``path`` does not name a supported set.
        ValueError: If the header is incomplete or the area is too small to
            contain at least one cell per axis.
    """
    file_path = path + instance_name

    # Check correctness of the file path.
    if not os.path.exists(file_path):
        raise FileNotFoundError("Path and instance don't exist")

    # Decide the set-specific layout once, before touching the file content.
    if "Vrp-Set-A" in path:
        num_skiprows = 6
        fixed_bounds = True
    elif "Vrp-Set-Li" in path:
        num_skiprows = 7
        fixed_bounds = False
    else:
        raise NotImplementedError("This set of VRP instances has not been implemented")

    # Read best known solution, number of nodes and vehicle capacity.
    best_known_solution = None
    n_orders = None
    max_load = None
    with open(file_path) as fp:
        for i, line in enumerate(fp):
            if i == 1:
                best_known_solution = int(float(line.split()[-1][:-1]))
            elif i == 3:
                n_orders = int(line.split()[2])
            elif i == 5:
                max_load = int(line.split()[2])
                break  # nothing else is needed from the header
    if best_known_solution is None or n_orders is None or max_load is None:
        raise ValueError("Instance header is incomplete")

    # Each data row is read as a single tab-separated column and then split on
    # whitespace; the section title line acts as the CSV header.
    df_position = pd.read_csv(
        file_path,
        sep="\t",
        skiprows=num_skiprows,
        nrows=n_orders,
    )
    node_fields = [raw.split() for raw in df_position.iloc[:, 0]]

    # The demand section follows the node section and its own title line.
    df_demand = pd.read_csv(
        file_path,
        sep="\t",
        skiprows=range(num_skiprows + 1 + n_orders),
        nrows=n_orders,
    )
    demand_fields = [raw.split() for raw in df_demand.iloc[:, 0]]

    new_orders = {
        'x': [float(fields[1]) for fields in node_fields],
        'y': [float(fields[2]) for fields in node_fields],
        'kg': [float(fields[1]) for fields in demand_fields],
        'service_time': [-1] * n_orders,
        'last_day': [-1] * n_orders,
        'yet_postponed': [False] * n_orders,
        'cell_id': [-1] * n_orders,
        'index': [int(fields[0]) for fields in node_fields],
    }

    if fixed_bounds:
        x_min, x_max = 0, 100
        y_min, y_max = 0, 100
    else:
        x_min, x_max = min(new_orders['x']), max(new_orders['x'])
        y_min, y_max = min(new_orders['y']), max(new_orders['y'])

    length = 5
    height = 5

    num_cell_x = int((x_max - x_min) / length)
    num_cell_y = int((y_max - y_min) / height)
    if num_cell_x < 1 or num_cell_y < 1:
        raise ValueError("The area must contain at least one cell along each axis")

    # Truncating to an integer number of cells may change the effective cell
    # size, so recompute it from the final cell counts.
    length = (x_max - x_min) / float(num_cell_x)
    height = (y_max - y_min) / float(num_cell_y)

    # The first node of the instance is used as the depot.
    depot_coord = [new_orders['x'][0], new_orders['y'][0]]
    seed = 202

    base_info = {'x_min': x_min, 'x_max': x_max,
                 'y_min': y_min, 'y_max': y_max,
                 'num_cell_x': num_cell_x, 'num_cell_y': num_cell_y,
                 'length': length, 'height': height,
                 'depot_coord': depot_coord, 'seed': seed,
                 'best_known_solution': best_known_solution, 'max_load': max_load}

    return pd.DataFrame(new_orders), base_info
