from sim import simulate
import config as conf
from plot import MultiPlot
from statistics_GC import AvgStatistics
import numpy as np
import json as js
from components.agent import Agent
import matplotlib.pyplot as plt


def reset_parameters():
    """Write the baseline scenario values onto the shared ``conf`` module.

    Overwrites ``conf.NBSS``, ``conf.SPV``, ``conf.BTH``, ``conf.TMAX`` and
    ``conf.F``; ``conf.BTH`` is derived from the current ``conf.C``.
    """
    baseline = (
        ("NBSS", 20),
        ("SPV", 500),
        ("BTH", conf.C * 0.9),
        ("TMAX", 0),  # other values tried: 480, 5
        ("F", 17),  # other value tried: 0
    )
    for attribute, value in baseline:
        setattr(conf, attribute, value)


def multi_plot(stats, x, legend_values, xlabel, legend_label):
    """Draw one multi-series figure per averaged metric held by ``stats``.

    Every figure is a ``MultiPlot`` built with ``x`` as ``xvalues``,
    ``xlabel`` as the x-axis label and ``legend_label`` as ``labels``, and is
    rendered with ``.plot(legend_values)``.
    """
    # (attribute of ``stats``, figure title, y-axis label or None), listed in
    # drawing order. Disabled figures are kept commented for quick re-enabling.
    figures = (
        ("avg_arrivals", "Arrivals", None),
        ("avg_loss", "Losses", None),
        # ("avg_avg_wait", "Avg wait", None),
        # ("avg_avg_ready", "Avg ready batteries", None),
        ("avg_cost", "Costs", "Euro per day"),
        ("avg_saving", "Savings", "Euro per day"),
        ("cost_per_service", "Cost per service", "Euro per service"),
    )
    for attribute, title, ylabel in figures:
        options = {
            "xvalues": x,
            "title": title,
            "labels": legend_label,
            "xlabel": xlabel,
        }
        # Only pass ylabel where the original call did, so MultiPlot's own
        # default applies to the other figures.
        if ylabel is not None:
            options["ylabel"] = ylabel
        MultiPlot(getattr(stats, attribute), **options).plot(legend_values)


def plot_stats(stats, params, label):
    """Plot the averaged metrics of a single-parameter sweep.

    ``label`` is passed to ``MultiPlot`` as the x values (and as ``labels``
    for the single-series figures); ``params`` is used as the x-axis label.
    """

    def stacked(*attributes):
        # One row per requested metric, in the order given.
        return np.array([getattr(stats, name).tolist() for name in attributes])

    # Single-series figures.
    # NOTE(review): "Arrivals" is the only one built without xlabel=params;
    # confirm whether that is intentional.
    arrivals_fig = MultiPlot(stats.avg_arrivals, title="Arrivals", xvalues=label, labels=label)
    arrivals_fig.single_plot()
    losses_fig = MultiPlot(stats.avg_loss, title="Losses", xvalues=label, xlabel=params, labels=label)
    losses_fig.single_plot()
    # MultiPlot(stats.avg_avg_wait, title="Waiting", xvalues=label, labels=label).single_plot()
    # MultiPlot(stats.avg_avg_ready, title="Average ready", xvalues=label, labels=label).single_plot()
    costs_fig = MultiPlot(
        stats.avg_cost,
        title="Costs",
        xvalues=label,
        xlabel=params,
        ylabel="Euro per day",
        labels=label,
    )
    costs_fig.single_plot()

    # Cost breakdown: three series on one figure.
    cost_rows = stacked("avg_cost", "avg_saving", "avg_net_cost")
    cost_breakdown_fig = MultiPlot(
        cost_rows, title="Costs", xlabel=params, ylabel="Euro per day", xvalues=label
    )
    cost_breakdown_fig.plot(["Grid", "Sold", "Net"])

    # Consumption breakdown: three series on one figure.
    consumption_rows = stacked("avg_tot_consumption", "avg_consumption", "avg_spv_consumption")
    consumption_fig = MultiPlot(
        consumption_rows,
        title="Consumption",
        xlabel=params,
        ylabel="Energy per day [Wh/day]",
        xvalues=label,
    )
    consumption_fig.plot(["Tot", "Grid", "SPV"])

    # Cost metrics paired with the loss probability.
    cost_vs_loss = MultiPlot(stats.avg_cost, stats.avg_loss_prob, title="Cost / prob loss")
    cost_vs_loss.plot_cost_prob_loss(label)

    service_cost_vs_loss = MultiPlot(
        stats.cost_per_service, stats.avg_loss_prob, title="Cost per service / prob loss"
    )
    service_cost_vs_loss.plot_cost_prob_loss(label)


if __name__ == "__main__":
    # reset_parameters()

    # Output folder
    output_folder = "results/output_data/"

    # --- SPV / NBSS sweep grids and their accumulators ---
    spv_list = [0, *range(300, 1600, 200)]  # alternative: range(500, 1600, 100)
    nbss_list = list(range(10, 21))  # alternative: range(10, 21, 1)
    stats_by_nbss = AvgStatistics(len(nbss_list), len(spv_list))
    stats_by_spv = AvgStatistics(r=len(spv_list))

    # --- F / TMAX sweep grids and their accumulators ---
    # f_list is bounded by conf.NBSS as it stands here (the reset_parameters()
    # call above is commented out).
    f_list = [0, *range(10, conf.NBSS + 1)]
    tmax_list = [0, *range(60, 660, 60)]  # alternative: range(60, 660, 60)
    stats_by_tmaxf = AvgStatistics(len(tmax_list), len(f_list))
    stats_by_f = AvgStatistics(r=len(f_list))
    stats_by_Tmax = AvgStatistics(r=len(tmax_list))

    # --- BTH sweep: fractions 0.60 .. 0.90 in steps of 0.05 ---
    bth_list = [percent / 100 for percent in range(60, 95, 5)]
    stats_by_bth = AvgStatistics(r=len(bth_list))

    # --- Beta dependency ---
    # NOTE(review): beta_list is assigned again further down, before its
    # first use.
    beta_list = [0.985]

    # --- Arrival-rate profiles ---
    arrival_list = [
        conf.arrival_rate,
        conf.arrival_rate_2,
        conf.arrival_rate_3,
    ]
    stats_by_arr_rate = AvgStatistics(r=len(arrival_list))

    #%% Multiple PV sizes and NBSS
#    for spv in spv_list:
#        for nbss in nbss_list:
#            conf.SPV = spv
#            conf.NBSS = nbss
#            stats = simulate()
#            print("-")
#    
#            stats_by_nbss.compute_avg(stats, nbss_list.index(conf.NBSS), spv_list.index(conf.SPV))
#                        
#    multi_plot(stats_by_nbss, spv_list, nbss_list, "SPV", "NBSS")
    
    #%% Multiple PV sizes
#    for spv in spv_list:
#        conf.SPV = spv
#        stats = simulate()
#        print("-")
#    
#        stats_by_spv.compute_avg(stats, spv_list.index(conf.SPV))
#    
#    plot_stats(stats_by_spv, "SPV", spv_list)

    #%% Multiple NBSS values
#    for nbss in nbss_list:
#        conf.NBSS = nbss
#        stats = simulate()
#        print("-")
#    
#        stats_by_nbss.compute_avg(stats, nbss_list.index(conf.NBSS))
#    
#    plot_stats(stats_by_nbss, "NBSS", nbss_list)

    #%% Multiple values of BETA
    
    def print_results(stats):
        """Print the averages of one simulation's statistics to stdout.

        ``stats`` must expose ``arrivals``, ``loss``, ``cost``, ``net_cost``,
        ``total_consumption``, ``consumption``, ``spv_production`` and
        ``saving`` as mappings; the values of each mapping are averaged with
        ``np.mean``.
        """

        def average(series):
            # Mean over the values of one statistic mapping.
            return np.mean(list(series.values()))

        print("---------")
        mean_arrivals = average(stats.arrivals)
        print("Mean arrivals: %f" % mean_arrivals)
        mean_losses = average(stats.loss)
        print("Mean loss: %f" % mean_losses)
        mean_cost = average(stats.cost)
        print("Mean cost: %f" % mean_cost)
        print("Mean net cost: %f" % average(stats.net_cost))
        # Denominator is arrivals minus losses (presumably the served requests).
        print("Cost per service: %f" % (mean_cost / (mean_arrivals - mean_losses)))
        print("Mean consumption: %f" % average(stats.total_consumption))
        print("Mean grid consumption: %f" % average(stats.consumption))
        print("Mean SPV: %f" % average(stats.spv_production))
        print("Mean saving: %f" % average(stats.saving))
    
    # Beta sweep: run one agent-driven simulation per value in beta_list, then
    # report each run and plot mean loss and mean grid consumption against beta.
    # Alternative sweeps kept for reference:
    # beta_list = np.arange(0.012, 0.019, 0.001)
    # beta_list = [0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.925]
    beta_list = [0.99]  # FinalRL

    # Start the sweep from the baseline configuration.
    reset_parameters()
    results_DP = []
    for beta in beta_list:
        print("======= SIMULATING FOR BETA = ", beta, " =======")
        results_DP.append(simulate(beta_value=beta, agent_on=True, agent_class=1))

    # Report every run and collect the quantities to plot.
    mean_loss = []
    mean_cons = []
    for beta, stats in zip(beta_list, results_DP):
        # Nothing is simulated here, so the header says "RESULTS" rather than
        # repeating the "SIMULATING" banner of the loop above.
        print("======= RESULTS FOR BETA = ", beta, " =======")
        print_results(stats)
        mean_loss.append(np.mean(list(stats.loss.values())))
        mean_cons.append(np.mean(list(stats.consumption.values())))

    # Markers keep a single-value sweep visible: a marker-less line through
    # one point draws nothing.
    plt.plot(beta_list, mean_loss, marker="o")
    plt.xlabel("Beta")
    plt.ylabel("Mean loss")
    plt.show()
    plt.plot(beta_list, mean_cons, marker="o")
    plt.xlabel("Beta")
    plt.ylabel("Mean grid consumption")
    plt.show()
    
    #%% Multiple BTH values
#    for bth in bth_list:
#        conf.BTH = conf.C * bth
#        stats = simulate()
#        print("-")
#    
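#        # NOTE(review): bth_list holds the fractions (0.60, 0.65, ...), whereas
#        # conf.BTH is conf.C * bth, so bth_list.index(conf.BTH) raises ValueError
#        # unless conf.C == 1. Index by the loop variable instead:
#        stats_by_bth.compute_avg(stats, bth_list.index(bth))
#        #stats_by_bth.compute_avg(stats, bth_list.index(conf.BTH))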
#    plot_stats(stats_by_bth, "Bth", bth_list)

    #%% Multiple Tmax and f values
#    for tmax in tmax_list:
#        for f in f_list:
#            conf.TMAX = tmax
#            conf.F = f
#            print(tmax, f)
#            stats = simulate()
#            print("-")
#
#            stats_by_tmaxf.compute_avg(stats, tmax_list.index(tmax), f_list.index(f))
#
#    multi_plot(stats_by_tmaxf, f_list, tmax_list, "F", "TMAX")

    #%% Multiple f values
    #print(conf.B)
#    for f in f_list:
#        conf.F = f
#        stats = simulate()
#        print("-")
#    
#        stats_by_f.compute_avg(stats, f_list.index(conf.F))
#    
#    plot_stats(stats_by_f, "F", f_list)

    #%% Multiple Tmax values
#    for tmax in tmax_list:
#        conf.TMAX = tmax
#        stats = simulate()
#        print("-")
#    
#        stats_by_Tmax.compute_avg(stats, tmax_list.index(conf.TMAX))
#    
#    plot_stats(stats_by_Tmax, "TMAX", tmax_list)
    
    # %% By arrival rate
#    for i in range(len(arrival_list)):
#        conf.arrival_rate = arrival_list[i]
#        stats = simulate()
#        print("-")
#    
#        stats_by_arr_rate.compute_avg(stats, i)
#    
#    plot_stats(stats_by_arr_rate, list(range(3)), ["3 peaks", "2 peaks", "Fixed coeff"])
