import gurobipy as grb
from gurobipy import quicksum
import time
import numpy as np
from functions import *

from solvers.CLSP import CLSP

class CLSP_FR(CLSP):
    """
    Classe per la risoluzione del problema CLSP (Capacitated Lot-Sizing Problem) 
    con domanda stocastica, con la mateuristica Fix&Relax
    """
    def __init__(self,**setting):
        self.name = "CLSP_FR"
        self.setting = setting

    def populate(self, instance, tree, I0, t, fixed):
        """
        Questo metodo, a partire dai dati di input della classe, popola il modello definendone 
        variabili, vincoli e funzione obiettivo; le variabili s sono definite continue o intere
        a seconda dell'istante temporale; l'input fixed riporta le variabili che sono state
        fissate in iterazioni precedenti (dunque non sono più variabili)
        """
        #parametri del modello (costi, capacità massima)
        h=instance.h
        g=instance.g
        f=instance.f
        proc_time=instance.proc_time
        R=instance.R
        setup_times=instance.setup_times
        num_items=instance.num_items
        time_periods=instance.time_periods
        N=instance.N
        
        #parametri albero di scenari
        p=probs(tree)
        T=time_period(tree)
        L=tree.get_leaves()
        a=predecessor(tree)
        omega=ancestor(tree)
        sigma=successors_t(tree)
        d=np.ndarray(shape=(num_items,N))
        for i in range(N):
            if i==0:
                d[:,i]=tree.get_history_node(i)
            else:
                d[:,i]=tree.get_history_node(i)[T[i],:]
        N_L=[]
        for i in range(N):
            if i not in(L):
                N_L.append(i)
                
        #suddivisione nodi in base a periodo di tempo, rispetto a t
        
        #insieme degli indici dei nodi il cui periodo temporale è uguale a t
        i1=[i for i in range(len(T)) if T[i]==t]
        #insieme degli indici dei nodi il cui periodo temporale è maggiore di t
        i2=[i for i in range(len(T)) if T[i]>t]
        
        #numero di nodi il cui periodo temporale è uguale a t 
        #(numero di variabili di set up intere)
        Nt=len(i1)
        #numero di nodi il cui periodo temporale è maggiore di t
        #(numero di variabili di set up rilassate)
        Ngt=len(i2)
        
        #indice del primo nodo il cui periodo temporale è uguale a t
        n1=min(i1)
        #indice del primo nodo il cui periodo temporale è maggiore di t
        n2=max(i1)+1
                
        model = grb.Model(self.name)
        
        #VARIABILI DEL MODELLO
        
        y = model.addMVar(
            shape=(num_items,time_periods,N),
            lb=0,
            vtype=grb.GRB.CONTINUOUS,
            name="y"
        )
        
        self.y=y
        
        z = model.addMVar(
            shape=(num_items,N),
            lb=0,
            vtype=grb.GRB.CONTINUOUS,
            name="z"
        )
        
        self.z=z
        
        I = model.addMVar(
            shape=(num_items,N),
            lb=0,
            vtype=grb.GRB.CONTINUOUS,
            name="I"
        )
        
        self.I=I
        
        #PRIMO TIME PERIOD, NODO 0
        
        if (t==0):
            
            #VARIABILI DI SET UP INTERE
            s = model.addMVar(
                shape=(num_items,1),
                vtype=grb.GRB.BINARY,
                name="s"
            )
            
            self.s=s
            
            #VARIABILI DI SET UP RILASSATE
            relaxed = model.addMVar(
                shape=(num_items,N-1),
                vtype=grb.GRB.CONTINUOUS,
                lb=0,
                ub=1,
                name="relaxed"
            )
            
            self.relaxed=relaxed
            
            #VINCOLI E FUNZIONE OBIETTIVO NEL CASO SPECIFICO
            
            model.addConstrs((y[i,t,0]<=(max(d[i,j] for j in sigma[0,t]))*s[i,0]) for i in range(num_items) for t in range(T[0]+1,time_periods))
            
            model.addConstrs((y[i,t,n]<=(max(d[i,j] for j in sigma[n,t]))*relaxed[i,n-1]) for i in range(num_items) for n in range(1,N) for t in range(T[n]+1,time_periods))
            
            model.addConstrs((y[i,T[0],0]<=d[i,0]*s[i,0]) for i in range(num_items))
            
            model.addConstrs((y[i,T[n],n]<=d[i,n]*relaxed[i,n-1]) for i in range(num_items) for n in range(1,N))
            
            model.addConstr((quicksum(quicksum(proc_time[i]*y[i,t,0] for t in range(T[0],time_periods)) for i in range(num_items))+ quicksum(setup_times[i]*s[i,0] for i in range(num_items))<=R))
            
            model.addConstrs((quicksum(quicksum(proc_time[i]*y[i,t,n] for t in range(T[n],time_periods)) for i in range(num_items))+ quicksum(setup_times[i]*relaxed[i,n-1] for i in range(num_items))<=R) for n in range(1,N))
            
            model.setObjective(quicksum(p[n]*quicksum(f[i]*s[i,n]+h[i]*I[i,n]+g[i]*z[i,n] for i in range(num_items)) for n in range(n2)) + quicksum(p[n]*quicksum(f[i]*relaxed[i,n-n2]+h[i]*I[i,n]+g[i]*z[i,n] for i in range(num_items)) for n in range(n2,N)) + quicksum(p[n]*quicksum(quicksum(h[i]*(t-T[n]) * y[i,t,n] for t in range(T[n]+1,time_periods)) for i in range(num_items)) for n in N_L), grb.GRB.MINIMIZE)
        
        #ULTIMO TIME PERIOD
        
        elif (t==time_periods-1):
            
            #VARIABILI DI SET UP INTERE
            s = model.addMVar(
                shape=(num_items,Nt),
                vtype=grb.GRB.BINARY,
                name="s"
            )
            
            self.s=s
            
            #VINCOLI E FUNZIONE OBIETTIVO NEL CASO SPECIFICO
            
            model.addConstrs((y[i,t,n]<=(max(d[i,j] for j in sigma[n,t]))*fixed[i,n]) for i in range(num_items) for n in range(n1) for t in range(T[n]+1,time_periods))
            
            model.addConstrs((y[i,t,n]<=(max(d[i,j] for j in sigma[n,t]))*s[i,n-n1]) for i in range(num_items) for n in range(n1,N) for t in range(T[n]+1,time_periods))
            
            model.addConstrs((y[i,T[n],n]<=d[i,n]*fixed[i,n]) for i in range(num_items) for n in range(n1))
            
            model.addConstrs((y[i,T[n],n]<=d[i,n]*s[i,n-n1]) for i in range(num_items) for n in range(n1,N))
            
            model.addConstrs((quicksum(quicksum(proc_time[i]*y[i,t,n] for t in range(T[n],time_periods)) for i in range(num_items))+ quicksum(setup_times[i]*fixed[i,n] for i in range(num_items))<=R) for n in range(n1))
            
            model.addConstrs((quicksum(quicksum(proc_time[i]*y[i,t,n] for t in range(T[n],time_periods)) for i in range(num_items))+ quicksum(setup_times[i]*s[i,n-n1] for i in range(num_items))<=R) for n in range(n1,N))
            
            model.setObjective(quicksum(p[n]*quicksum(f[i]*fixed[i,n]+h[i]*I[i,n]+g[i]*z[i,n] for i in range(num_items)) for n in range(n1))+quicksum(p[n]*quicksum(f[i]*s[i,n-n1]+h[i]*I[i,n]+g[i]*z[i,n] for i in range(num_items)) for n in range(n1,N)) + quicksum(p[n]*quicksum(quicksum(h[i]*(t-T[n]) * y[i,t,n] for t in range(T[n]+1,time_periods)) for i in range(num_items)) for n in N_L), grb.GRB.MINIMIZE)
        
        #TIME PERIODS INTERMEDI
        
        else:
            
            #VARIABILI DI SET UP INTERE
            s = model.addMVar(
                shape=(num_items,Nt),
                vtype=grb.GRB.BINARY,
                name="s"
            )
            
            self.s=s
            
            #VARIABILI DI SET UP RILASSATE
            relaxed = model.addMVar(
                shape=(num_items,Ngt),
                vtype=grb.GRB.CONTINUOUS,
                lb=0,
                ub=1,
                name="relaxed"
            )
            
            self.relaxed=relaxed
            
            #VINCOLI E FUNZIONE OBIETTIVO NEL CASO SPECIFICO
            
            model.addConstrs((y[i,t,n]<=(max(d[i,j] for j in sigma[n,t]))*fixed[i,n]) for i in range(num_items) for n in range(n1) for t in range(T[n]+1,time_periods))
            
            model.addConstrs((y[i,t,n]<=(max(d[i,j] for j in sigma[n,t]))*s[i,n-n1]) for i in range(num_items) for n in range(n1,n2) for t in range(T[n]+1,time_periods))
            
            model.addConstrs((y[i,t,n]<=(max(d[i,j] for j in sigma[n,t]))*relaxed[i,n-n2]) for i in range(num_items) for n in range(n2,N) for t in range(T[n]+1,time_periods))
            
            model.addConstrs((y[i,T[n],n]<=d[i,n]*fixed[i,n]) for i in range(num_items) for n in range(n1))
            
            model.addConstrs((y[i,T[n],n]<=d[i,n]*s[i,n-n1]) for i in range(num_items) for n in range(n1,n2))
            
            model.addConstrs((y[i,T[n],n]<=d[i,n]*relaxed[i,n-n2]) for i in range(num_items) for n in range(n2,N))
            
            model.addConstrs((quicksum(quicksum(proc_time[i]*y[i,t,n] for t in range(T[n],time_periods)) for i in range(num_items))+ quicksum(setup_times[i]*fixed[i,n] for i in range(num_items))<=R) for n in range(n1))
            
            model.addConstrs((quicksum(quicksum(proc_time[i]*y[i,t,n] for t in range(T[n],time_periods)) for i in range(num_items))+ quicksum(setup_times[i]*s[i,n-n1] for i in range(num_items))<=R) for n in range(n1,n2))
            
            model.addConstrs((quicksum(quicksum(proc_time[i]*y[i,t,n] for t in range(T[n],time_periods)) for i in range(num_items))+ quicksum(setup_times[i]*relaxed[i,n-n2] for i in range(num_items))<=R) for n in range(n2,N))
            
            model.setObjective(quicksum(p[n]*quicksum(f[i]*fixed[i,n]+h[i]*I[i,n]+g[i]*z[i,n] for i in range(num_items)) for n in range(n1))+quicksum(p[n]*quicksum(f[i]*s[i,n-n1]+h[i]*I[i,n]+g[i]*z[i,n] for i in range(num_items)) for n in range(n1,n2)) + quicksum(p[n]*quicksum(f[i]*relaxed[i,n-n2]+h[i]*I[i,n]+g[i]*z[i,n] for i in range(num_items)) for n in range(n2,N)) + quicksum(p[n]*quicksum(quicksum(h[i]*(t-T[n]) * y[i,t,n] for t in range(T[n]+1,time_periods)) for i in range(num_items)) for n in N_L), grb.GRB.MINIMIZE)
            
#VINCOLI INDIPENDENTI DA VARIABILE S (FISSA, BINARIA O RILASSATA)
            
        model.addConstrs((I[i,a[n]]+quicksum(y[i,T[n],omega[n,t]] for t in range(T[n]))+y[i,T[n],n]==d[i,n]+I[i,n]-z[i,n]) for i in range(num_items) for n in range(1,N))
        
        model.addConstrs((I0[i]+y[i,T[0],0]==d[i,0]+I[i,0]-z[i,0]) for i in range(num_items))
        
        model.update()
        return model

    def get_solution(self, instance, model, tree, t, time_limit=None, gap=None, verbose=False):
        """
        Questo metodo risolve un modello definito, restituendo il valore della funzione obiettivo,
        la decisione di produzione all'istante iniziale, le corrispondenti variabili di set up,
        il tempo di risoluzione del modello dell'iterazione t e il valore calcolato delle variabili intere
        all'iterazione t (per fissarle all'iterazione successiva)"""
        #gap (per criterio di arresto) del problemo misto intero
        if gap:
            model.setParam('MIPgap', gap)
        #tempo limite per la risoluzione del problemo misto intero
        if time_limit:
            model.setParam(grb.GRB.Param.TimeLimit, time_limit)
        #opzione per avere un output che mostra i dettagli della risoluzione
        #del modello o meno
        if verbose:
            model.setParam('OutputFlag', 1)
        else:
            model.setParam('OutputFlag', 0)
        if verbose:
            print ('Solving a model with: '+str(model.NumConstrs)+' constraints')
            print ('    and: ' +str(model.NumVars)+ ' variables')

        #risoluzione del modello e tempo computazionale dell'iterazione  
        start = time.time()
        model.optimize()
        end = time.time()
        comp_time = end - start
        
        
        num_items = instance.num_items
        time_periods=instance.time_periods
        N=instance.N
        T=time_period(tree)
        
        #VARIABILI S CALCOLATE, DA AGGIORNARE IN FIXED ALL'ITERAZIONE SUCCESSIVA
        
        #insieme dei nodi il cui periodo temporale è uguale a t
        i1=[i for i in range(len(T)) if T[i]==t]

        #numero di nodi il cui periodo temporale è uguale a t 
        #(numero di variabili di set up intere)
        Nt=len(i1)
        
        computed_s=np.zeros((num_items,Nt))
        
        for i in range(num_items):
            for j in range(Nt):
                computed_s[i,j]=self.s[i,j].X
        
        #CALCOLO DELLE DECISIONI INIZIALI (QUANTITA' PRODOTTE PER OGNI ITEM AL TEMPO 0 Y 
        # E RELATIVE VARIABILI DI SETUP S)
        
        sol = np.zeros((num_items,time_periods,N))
        
        for i in range(num_items):
            for j in range(time_periods):
                for k in range(N):
                    sol[i,j,k] = self.y[i,j,k].X
        
        Y=[0]*num_items
        S=[0]*num_items
        
        for i in range(num_items):
            Y[i]=sum(sol[i,:,0])
            if Y[i]>0:
                S[i]=1
        
        #VALORE DELLA FUNZIONE OBIETTIVO
        of = model.getObjective().getValue()
            
        model.reset()
        return of, Y, S, comp_time, computed_s
        
    def solve(
        self, instance, tree, I0, time_limit=None, gap=None, verbose=False
    ):
        """
        Questo è il metodo dove il problema viene definito e risolto
        :param instance: dizionario contenente tutti i parametri necessari alla definizione del modello
        :param tree: albero di scenari che rappresenta la domanda stocastica
        :I0: valore iniziale del magazzino
        :param time_limit: per interrompere il solver gurobi dopo "time_limit" secondi; se non è presente
        il solver si interrompe quando trova una soluzione 'distante meno di' "gap" dalla 
        soluzione ottima
        :param gap: il solver gurobi si ferma quando trova una soluzione 'distante meno di' "gap"
        dalla soluzione ottima; il valore standard è 0.0001
        :param verbose: parametro (True/False) da passare per avere un output che stampa i dettagli della 
        soluzione del modello o meno
        :return: valore della funzione obiettivo, decisioni al primo stadio (Y e S) e tempo computazionale del solver
        Fix and Relax(somma dei tempi di tutte le iterazioni)
        """
        #NUMERO DI ISTANTI DI TEMPO DELL'ALBERO DI SCENARI
        time_periods=instance.time_periods
        #NUMERO DI PRODOTTI
        num_items=instance.num_items
        #NUMERO DI NODI
        N=instance.N
        #MATRICE IN CUI SI SALVANO LE VARIABILI INTERE GIA' FISSATE
        fixed=np.zeros((num_items,N))
        T=time_period(tree)
        tot_time=0
        
        #AD OGNI ISTANTE DI TEMPO "SI POPOLA" IL MODELLO, PER AGGIORNARE LE VARIABILI FISSE, QUELLE BINARIE E QUELLE RILASSATE
        for t in range(time_periods):
            model = self.populate(instance, tree, I0, t, fixed)
            [of,Y,S,comp_time,computed_s]=self.get_solution(instance, model, tree, t, time_limit=time_limit, gap=gap, verbose=verbose)
            
            i1=[i for i in range(len(T)) if T[i]==t]
            n1=min(i1)
            #LE VARIABILI CHE ERANO BINARIE A QUESTA ITERAZIONE VENGONO "FISSATE" AL VALORE CHE E' STATO CALCOLATO
            #(NELLA MATRICE fixed). ALL'ITERAZIONE SUCCESSIVA SARANNO FISSE
            for k in i1:
                fixed[:,k]=computed_s[:,k-n1]
            #SOMMA DEL TEMPO COMPUTAZIONALE DEL SOTTO PROBLEMA DELL'ISTANTE t A QUELLI PRECEDENTI
            #(PER OTTENERE TEMPO TOTALE)
            tot_time=tot_time+comp_time
            
        return of,Y,S,tot_time