from data_manager import DatasetManager
from queue import PriorityQueue
import config as conf
import numpy as np
from scipy.stats import poisson
import matplotlib.pyplot as plt

#Simulation parameters
TIME_GRANULARITY = 5  # length of one DP time step [min]
LAST_K = round(60 * 24 / TIME_GRANULARITY)  # number of DP steps in one day

#Model parameters
BAT_STATES = 100  # number of discrete battery states; 0 = full, BAT_STATES-1 = emptiest
BAT_UNIT = conf.C / BAT_STATES  # charge represented by one state step
# States at or below this index count as "ready to be swapped" (x <= BAT_TH)
BAT_TH = round((conf.C - conf.BTH) / BAT_UNIT)

#Control parameters
BETA = 0  # weight of the not-ready penalty vs. the energy cost (set by Agent)
GAMMA = 0  # coefficient of the quadratic terminal cost GAMMA * x**2
DISC = 0.9  # discount factor (0.996 was the commented-out alternative)
# Max charge per step. NOTE(review): divides by 120, not 60 -- confirm units of conf.CR
U_MAX = conf.CR*TIME_GRANULARITY/120
# Admissible controls in battery-state units: idle or charge at U_MAX
U_VECTOR = np.array([0, round(U_MAX / BAT_UNIT)])
VALUE_ITERATIONS = 2  # extra day-wrapping passes run by value_iteration

#%% GLOBAL VARIABLES

#Data vectors (per-step day profiles, refilled by Agent.computeSingleSocketTable)
PV_vector = []  # PV energy available per step
Price_vector = []  # electricity price per step
Arrival_prob = []  # probability that the current socket's battery is requested

#%% FUNCTIONS

#Probability of at least n arrivals
def arrival_prob(mu, n=1):
    """Return P(N >= n) for N ~ Poisson(mu), never exactly zero.

    Uses the survival function ``poisson.sf(n-1, mu)``. It equals
    ``1 - poisson.cdf(n-1, mu)`` mathematically, but it stays accurate in the
    upper tail, where the subtraction cancels to 0.

    mu: expected number of arrivals in one time step.
    n: minimum number of arrivals (n <= 0 gives probability 1).
    Returns a probability in (0, 1]; an exact 0 is replaced by machine epsilon.
    """
    prob = poisson.sf(n-1, mu)
    if prob <= 0:
        # Keep the original floor: never report an exactly-zero probability.
        prob = np.finfo(float).eps
    return prob

#Probabilities of at least 1..M arrivals (vector)
def arrival_prob_vector(mu, M=1):
    """Return [P(N >= 1), ..., P(N >= M)] for N ~ Poisson(mu).

    Entry m (0-based) is the probability that at least m+1 cars arrive. It is
    computed with the accurate survival function ``poisson.sf`` rather than
    ``1 - poisson.cdf``, which cancels to 0 in the upper tail. Exact zeros are
    replaced by machine epsilon.

    mu: expected number of arrivals in one time step.
    M: number of entries (number of sockets considered).
    Returns a float array of length M, or an empty float array when M <= 0.
    """
    if M <= 0:
        return np.array([])
    probs = poisson.sf(np.arange(M), mu)
    # Never report an exactly-zero probability.
    return np.where(probs <= 0, np.finfo(float).eps, probs)

#Ramp function for a single value
def positive(x):
    """Return x when it is strictly positive, otherwise the integer 0."""
    return 0 if x <= 0 else x
    
#Ramp function applied element-wise (arrays)
def positive_v(x):
    """Element-wise ramp: negative entries of ``x`` become 0, others are kept.

    Works by multiplying ``x`` by the boolean mask ``x >= 0``, so the result
    keeps ``x``'s numeric dtype.
    """
    non_negative = x >= 0
    return non_negative * x

#Cost function (single socket)
def mean_step_cost_function(k, x, u):
    """Expected one-step cost of applying control ``u`` in state ``x`` at step ``k``.

    Cost = (1-BETA) * price[k] * max(u - PV[k], 0)
         + BETA * Arrival_prob[k] * [max(x - u, 0) > BAT_TH]
    i.e. the energy bought beyond the available PV plus a penalty when a car
    may arrive while the battery is not ready. Reads the module globals
    PV_vector, Price_vector and Arrival_prob.
    NOTE(review): u is in battery-state units while PV_vector presumably holds
    energy per step -- confirm they share units.
    """
    grid_energy = positive(u - PV_vector[k])
    energy_cost = (1 - BETA) * Price_vector[k] * grid_energy
    post_state = positive(x - u)
    penalty = Arrival_prob[k] * BETA * int(post_state > BAT_TH)
    return energy_cost + penalty

#Total mean cost function (all sockets)
def tot_mean_step_cost_function(k, x_next, u, e_price, PV, mu, M):
    """Expected one-step cost of the joint control ``u`` over all sockets.

    Cost = (1-BETA) * e_price * max(sum(u) - PV, 0)
         + BETA * sum_i P(N >= i+1) * [x_next[i] > BAT_TH]
    where N ~ Poisson(mu) is the number of arrivals and x_next the
    post-charge states (most charged first). ``k`` is unused.
    """
    grid_energy = positive(np.sum(u) - PV)
    energy_cost = (1 - BETA) * e_price * grid_energy
    not_ready = x_next > BAT_TH
    penalty = BETA * np.sum(not_ready * arrival_prob_vector(mu, M))
    return energy_cost + penalty

#State transition vector
def next_state_v(k, x_next, u, arrivals, M):
    """Post-arrival states of all M sockets.

    The first ``arrivals`` sockets are requested; a requested socket whose
    state is ready (x_next <= BAT_TH) is swapped and jumps to BAT_STATES-1.
    Every other socket keeps its x_next value. ``k`` and ``u`` are unused.
    """
    requested = np.array([1] * arrivals + [0] * (M - arrivals))
    swapped = requested * (x_next <= BAT_TH)
    return swapped * (BAT_STATES - 1) + (1 - swapped) * x_next

#State transition (single socket)
def next_state(k, x, u, arrival):
    """Next state of one socket after charging ``u`` and an optional arrival.

    After charging, the state is max(x - u, 0). If ``arrival`` is 1 and that
    state is ready (<= BAT_TH), the battery is swapped and the state becomes
    BAT_STATES-1. ``k`` is unused.
    """
    post_charge = positive(x - u)
    swapped = arrival * int(post_charge <= BAT_TH)
    return swapped * (BAT_STATES - 1) + (1 - swapped) * post_charge

#Q-value computation
def J_computation(k,x,u, DP_table):
    """Q-value of control ``u`` in state ``x`` at step ``k``.

    This is the expected step cost plus the discounted cost-to-go, averaged
    over "a car arrives" (probability Arrival_prob[k]) and "no car arrives".
    DP_table[0] must be the cost-to-go table of step k+1; column 1 of each
    row holds the cost.
    """
    successor = DP_table[0]
    p_arrival = Arrival_prob[k]
    cost_if_arrival = successor[next_state(k, x, u, 1)][1]
    cost_if_none = successor[next_state(k, x, u, 0)][1]
    J = mean_step_cost_function(k, x, u)
    J += DISC * p_arrival * cost_if_arrival
    J += DISC * (1 - p_arrival) * cost_if_none
    return J

#Value iteration
def value_iteration(DP_table, iterations):
    """Run ``iterations`` backward DP sweeps over one day.

    Each sweep uses a copy of the previous result's stage-0 table as its
    terminal cost, so costs wrap around from one day to the next.

    DP_table: list whose element 0 is a per-state table of [u, J] rows.
    Returns a list of LAST_K + 1 tables: index k < LAST_K is stage k (rows
    [best u, best Q]); the last entry is the terminal table used by that
    sweep. With iterations == 0 the input is returned unchanged.
    """
    for _ in range(iterations):
        # Stages are built back-to-front, then reversed into stage order.
        stages = [np.copy(DP_table[0])]
        for k in reversed(range(LAST_K)):
            # J_computation only reads index 0 = cost-to-go of stage k+1.
            successor = [stages[-1]]
            stage_k = [0] * BAT_STATES
            for x in range(BAT_STATES):
                best_q = float("Inf")
                for u in U_VECTOR:
                    q = J_computation(k, x, u, successor)
                    # Strict '<': ties keep the first control in U_VECTOR.
                    if q < best_q:
                        best_q = q
                        stage_k[x] = [u, q]
            stages.append(stage_k)
        stages.reverse()
        DP_table = stages
    return DP_table

#Plot control matrix
def plot_control(DP_table):
    """Draw the optimal control of ``DP_table`` with ``plt.matshow``.

    Row x is the battery state and column k the time step:
    u_matrix[x, k] == DP_table[k][x][0].
    """
    u_matrix = np.zeros((BAT_STATES, LAST_K))
    for x, k in np.ndindex(u_matrix.shape):
        u_matrix[x, k] = DP_table[k][x][0]
    plt.matshow(u_matrix)

#Costs matrix (optionally plotted)
def plot_costs(DP_table, show=False):
    """Collect the cost-to-go of ``DP_table`` into a matrix and return it.

    Row x is the battery state and column k the time step:
    J_matrix[x, k] == DP_table[k][x][1].

    show: if True, also draw the matrix with ``plt.matshow``. It is off by
        default so existing callers do not open a figure.
    Returns the (BAT_STATES, LAST_K) float array, which was previously
    computed and then discarded.
    """
    J_matrix = np.zeros((BAT_STATES, LAST_K))
    for x in range(BAT_STATES):
        for k in range(LAST_K):
            J_matrix[x, k] = DP_table[k][x][1]
    if show:
        plt.matshow(J_matrix)
    return J_matrix
    
#%% Structure

class Agent:
    """DP-based charging controller for the sockets of one battery swap station.

    At construction it builds one DP lookup table per socket position
    (``control_table[i][k][x] == [u, J]``). ``apply_control`` then uses these
    tables as the cost-to-go when choosing which sockets charge at each step.
    """

    #Constructor
    def __init__(self, BSS, beta = 0):
        """Bind the station, set the BETA weight and precompute the control tables.

        BSS: station object; its ``sockets`` list is read here and in
            ``apply_control``.
        beta: weight of the not-ready penalty against the energy cost. It is
            stored in the module-global BETA, which the cost functions read.
        """
        global BETA
        BETA = beta
        # BETA is module-global: keep our own copy so apply_control can restore
        # it if another Agent with a different beta is constructed later.
        self.beta = beta
        print("Agent initiated with BETA = ", BETA)
        self.bss = BSS
        self.control_table = self.computeControlTables()

    @staticmethod
    def _battery_state(charge):
        """Map a battery charge to a DP state index in [0, BAT_STATES-1].

        Index 0 is a full battery and the index grows as the charge drops. The
        result is clamped: a (nearly) empty battery would otherwise map to
        BAT_STATES, one past the last row of the control tables.
        """
        bat = int(BAT_STATES*(charge/conf.C))
        return min(max(BAT_STATES - bat, 0), BAT_STATES - 1)

    #Evaluate control
    def evaluate_control(self, k, enum_sockets, u, mu, M, e_price, PV):
        """Approximate Q-value of the joint control ``u`` at time step ``k``.

        enum_sockets: (index, socket) pairs, most charged first.
        u: per-socket charge in battery-state units, aligned with enum_sockets.
        mu: expected number of arrivals in this step; M: number of sockets.
        Returns the expected step cost plus DISC times the expected cost-to-go.
        """
        x = np.array([self._battery_state(s[1].battery.charge) for s in enum_sockets])
        x_next = np.sort(positive_v(x-u))
        g = tot_mean_step_cost_function(k, x_next, u, e_price, PV, mu, M)
        rollout = self.expected_rollout(k, x_next, u, enum_sockets, M, mu)
        return g + DISC*rollout

    #Compute expected total cost
    def expected_rollout(self, k, x_next, u, s, M, mu):
        """Expected cost-to-go of the post-charge states ``x_next``.

        The number of arriving cars is N ~ Poisson(mu). Outcomes with a < M
        arrivals are weighted by P(N = a). All outcomes with N >= M (every
        socket requested) are lumped into a = M with weight P(N >= M), so the
        weights sum to 1. The cost-to-go is read from stage k+1, the stage
        after the decision taken at step k, the same convention as
        J_computation.
        """
        J = 0
        for arrivals in range(M + 1):
            if arrivals < M:
                p_a = poisson.pmf(arrivals, mu)
            else:
                p_a = arrival_prob(mu, M)  # P(N >= M)
            x_future = np.sort(next_state_v(k, x_next, u, arrivals, M))
            for s_idx in range(len(s)):
                # Tables hold LAST_K stages plus a terminal entry, so k+1 is valid.
                J += p_a*self.control_table[s_idx][k + 1][x_future[s_idx]][1]
        return J

    #Compute tables
    def computeControlTables(self):
        """Return one DP table per socket; table i is built with n_socket = i+1."""
        return [self.computeSingleSocketTable(i + 1) for i in range(len(self.bss.sockets))]

    #Compute single socket table
    def computeSingleSocketTable(self, n_socket):
        """Build the DP lookup table for the socket at charge-rank ``n_socket``.

        The arrival probability per step is P(at least n_socket arrivals). The
        returned list has LAST_K stage tables of [u, J] rows plus a terminal
        table (see value_iteration).
        """
        dm = DatasetManager()

        # Per-step day profiles, shared with the module-level DP helpers via globals
        global PV_vector, Price_vector, Arrival_prob
        PV_vector = []
        Price_vector = []
        Arrival_prob = []
        for k in range(LAST_K):
            h = int(k*TIME_GRANULARITY/60)
            PV_vector.append(dm.get_PV_power(conf.MONTH, conf.CURRENT_DAY, h, 1)* (TIME_GRANULARITY/60) * (1/conf.NBSS))
            Price_vector.append(dm.get_prices_electricity(conf.MONTH, conf.DAY, h)* 1e-6 * (TIME_GRANULARITY/60))
            Arrival_prob.append(arrival_prob(TIME_GRANULARITY/(conf.arrival_rate[h]), n_socket))

        # Quadratic terminal cost GAMMA * x**2 (rows are [u, J])
        J_N = [[0, GAMMA*(float(i)**2)] for i in range(BAT_STATES)]

        # One backward sweep from J_N, followed by VALUE_ITERATIONS sweeps that
        # each wrap the previous stage-0 costs around as the terminal cost.
        return value_iteration([J_N], VALUE_ITERATIONS + 1)

    #Apply control
    def apply_control(self, time):
        """Decide which busy, unbooked sockets charge during the current step.

        time: simulation time in minutes; only its minute-of-day is used.
        Side effects: sets ``is_charging`` on every considered socket. When a
        socket starts charging again, it also updates ``battery.last_update``
        and increments ``battery.charging_resumes``.
        """
        global BETA
        # Restore this agent's weight in case another Agent changed the global.
        BETA = self.beta

        dm = DatasetManager()

        # Discretize time
        pre_k = time % (60*24)  # minute of the day
        k = int(pre_k/TIME_GRANULARITY)
        h = int(k*TIME_GRANULARITY/60)

        # Busy, unbooked sockets, most charged first
        sockets = [x for x in self.bss.sockets if x.busy and not x.battery.booked]
        enum_sockets = sorted(enumerate(sockets), key=lambda x: x[1].battery.charge, reverse=True)

        # System data for this step
        mu = TIME_GRANULARITY/(conf.arrival_rate[h])
        M = len(enum_sockets)
        e_price = dm.get_prices_electricity(conf.MONTH, conf.DAY, h)* 1e-6 * (TIME_GRANULARITY/60)
        # NOTE(review): computeSingleSocketTable divides PV by conf.NBSS
        # (per-station share) but this value is not divided -- confirm intent.
        PV = dm.get_PV_power(conf.MONTH, conf.CURRENT_DAY, h, 1)* (TIME_GRANULARITY/60)

        # Start from "charge nothing". Then, one socket at a time, try every
        # control with the others fixed and keep any strict improvement.
        u_opt = [0]*M
        Q_opt = self.evaluate_control(k, enum_sockets, u_opt, mu, M, e_price, PV)
        for socket_cnt in range(M):
            for u in U_VECTOR:
                u_aux = np.copy(u_opt)
                u_aux[socket_cnt] = u
                Q_aux = self.evaluate_control(k, enum_sockets, u_aux, mu, M, e_price, PV)
                if Q_aux < Q_opt:
                    Q_opt = Q_aux
                    u_opt = u_aux

        # Apply the chosen controls; a battery already at conf.BTH stops charging.
        for u_s, (_, sock) in zip(u_opt, enum_sockets):
            if u_s != 0 and sock.battery.charge < conf.BTH:
                if not sock.is_charging:
                    sock.battery.last_update = time
                    sock.battery.charging_resumes += 1
                sock.is_charging = True
            else:
                sock.is_charging = False
        
        
        
        