import numpy as np
import networkx as nx



def prod(val):  
    res = 1 
    for ele in val:  
        res *= ele  
    return res


class ScenarioTree(nx.DiGraph):
    '''
    Questa classe definisce un albero di scenari
    Estende la classe nx.Digraph.
    '''
    def __init__(self, name: str, branching_factors: list, dim_observations: int, initial_value: np.ndarray , mean: np.ndarray , cov: np.ndarray, seed: int):
        """Inizializzazione dell'albero
        Input:
            name (str): nome dell'albero
            branching_factors (list): descrive i fattori di ramificazione dell'albero
            dim_observations (int): dimensione di ogni osservazione
            initial_value (np.darray): valore osservato al nodo radice
            mean (np.ndarray): valore atteso della distribuzione normale multivariata
            cov (np.ndarray): matrice di varianza e covarianza della distribuzione normale multivariata
            seed (int): seed per il campionamento
        """
        nx.DiGraph.__init__(self)
        self.starting_node = 0
        self.dim_observations = dim_observations
        self.mean = mean
        self.cov= cov
        self.seed=seed
        # Nodo radice
        self.add_node(
            self.starting_node,
            obs=initial_value,
            prob=1,
            t=0,
            id=0,
            stage=0
        )
        self.name = name
        self.breadth_first_search = []
        self.depth = len(branching_factors)
        self.branching_factors = branching_factors
        # Calcolo del numero di scenari totale
        self.n_scenarios = prod(self.branching_factors)
        count = 1
        last_added_nodes = [self.starting_node]
        n_nodes_per_level = 1
        
        np.random.seed(seed)
        # Generazione degli altri nodi
        for i in range(self.depth):
            next_level = []
            n_nodes_per_level *= self.branching_factors[i]
            # per ogni nodo 'genitore' si aggiungono nodi 'figli'
            for parent_node in last_added_nodes:
                # probabilità uniforme ad ogni ramificazione
                probs=1/self.branching_factors[i]
                for j in range(self.branching_factors[i]):
                    id_new_node = count
                    
                    # se fattore di ramificazione è 1, nel nodo si osserva il valore atteso
                    if(self.branching_factors[i]==1):
                        sample=mean[0,:]
                    # altrimenti, campionamento da normale multivariata
                    else:
                        sample=np.random.multivariate_normal(mean[0,:], cov)
                        for k in range(len(sample)):
                            if sample[k]<0:
                                sample[k]=0
                    
                    self.add_node(
                        id_new_node,
                        obs=sample,
                        prob=self.nodes[parent_node]['prob'] * probs,
                        t=i + 1,
                        id=count,
                        stage=i + 1
                    )
                    self.add_edge(parent_node, id_new_node)
                    next_level.append(id_new_node)
                    count += 1
            last_added_nodes = next_level
            self.n_nodes = count
        self.leaves = last_added_nodes

    def get_leaves(self):
        # Restituisce le foglie dell'albero
        return self.leaves

    def get_history_node(self, n):
        # Dato l'indice di un nodo, restituisce tutte le osservazioni dal nodo radice ad esso
        ris = self.nodes[n]['obs'].reshape((1, self.dim_observations))
        # caso nodo radice
        if n == 0:
            return ris
        # altrimenti iterando all'indietro sui predecessori
        while n != self.starting_node:
            n = list(self.predecessors(n))[0]
            ris = np.vstack(
                (self.nodes[n]['obs'].reshape((1, self.dim_observations)), ris)
            )
        return ris