import matplotlib.pyplot as plt
import numpy as np
import config as conf

import math
from matplotlib.ticker import FormatStrFormatter

from matplotlib import rcParams
#rcParams.update({'figure.autolayout': True})

import matplotlib.ticker
class MyLocator(matplotlib.ticker.AutoMinorLocator):
    """AutoMinorLocator whose ``n`` argument defaults to 23."""

    def __init__(self, n=23):
        # Explicit two-argument super(): once the rebinding below has run,
        # the module attribute AutoMinorLocator *is* MyLocator, so the base
        # class must not be looked up through matplotlib.ticker here.
        super(MyLocator, self).__init__(n=n)


# Rebind the module attribute so every later
# `matplotlib.ticker.AutoMinorLocator(...)` call in this file builds a
# MyLocator (an explicit n, as passed by the plotting methods, still wins).
# NOTE(review): this is a process-wide monkey-patch; any other code that looks
# the name up on matplotlib.ticker is affected as well.
matplotlib.ticker.AutoMinorLocator = MyLocator


# Figure style preset ("baseline", active).
# The alternative "stats_by_nss" preset previously kept here commented out was:
#   font size 28, x/y tick label size 28, legend title size 24,
#   legend size 40, n_ytic_min 5, n_xtic_min 0.
# Optionally: plt.rcParams["ytick.minor.visible"] = False
rc_font_size = 28
rc_xtic_labelsize = rc_ytic_labelsize = 28
rc_legendtitle_fontsize, rc_legend_fontsize = 18, 20
# Subdivision counts passed to AutoMinorLocator by the plotting methods below.
n_ytic_min = n_xtic_min = 5

plt.rcParams.update({
    'font.size': rc_font_size,
    'xtick.labelsize': rc_xtic_labelsize,
    'ytick.labelsize': rc_ytic_labelsize,
    'legend.title_fontsize': rc_legendtitle_fontsize,
    'legend.fontsize': rc_legend_fontsize,
})


class Plot:
    """Single-series figure helper built on ``matplotlib.pyplot``.

    Every plotting method opens a new figure from the stored data and
    decorations. All of them finish with ``plt.show()`` except
    ``plot_by_hour``, which leaves the figure open for the caller.
    """

    def __init__(self, yvalues, xvalues=None, title="", labels="",
                 xlabel="", ylabel="", save=0):
        """Store the data and decorations used by the plotting methods.

        Args:
            yvalues: iterable of y values.
            xvalues: iterable of x values; ``None`` (the default) is stored
                as an empty list, meaning "not given". ``plot_by_day`` draws
                it as a second series.
            title: figure title; also the file stem used when saving.
            labels: not used by ``Plot`` itself; subclasses may use it.
            xlabel: x-axis label (not applied by ``Plot``'s own methods).
            ylabel: y-axis label, applied when non-empty by
                ``plot_by_day`` and ``plot_by_hour``.
            save: truthy to write ``plots/<title>.png`` from ``plot``,
                ``plot_by_day`` and ``scatter``.
        """
        self.yvalues = yvalues
        # A fresh list per instance instead of one shared mutable default.
        self.xvalues = [] if xvalues is None else xvalues
        self.title = title
        self.save = save
        self.labels = labels
        self.xlabel = xlabel
        self.ylabel = ylabel

    def _save_png(self):
        """Save the current figure to ``plots/<title>.png`` if ``save`` is set."""
        if self.save:
            plt.savefig(f"plots/{self.title}.png", bbox_inches = 'tight')

    def plot(self):
        """Line plot of yvalues against xvalues, or against the index if none."""
        plt.figure()
        plt.grid()
        plt.title(self.title)

        y = list(self.yvalues)
        x = list(self.xvalues)
        if x:
            plt.plot(x, y)
        else:
            # No x values (the constructor default): let matplotlib use
            # 0..len(y)-1 instead of failing on a length mismatch.
            plt.plot(y)

        self._save_png()
        plt.show()

    def plot_by_day(self):
        """Plot yvalues (and xvalues, if given) against range(conf.SIM_LAST).

        Both series must have exactly ``conf.SIM_LAST`` entries
        (presumably one per simulated day).
        """
        plt.figure()
        plt.grid()
        plt.title(self.title)
        if self.ylabel:
            plt.ylabel(self.ylabel)

        days = range(conf.SIM_LAST)
        plt.plot(days, list(self.yvalues))
        if len(self.xvalues):
            plt.plot(days, list(self.xvalues))

        self._save_png()
        plt.show()

    def plot_by_hour(self):
        """Plot 24 values at x = 0..23 on a new figure (no ``plt.show()``)."""
        fig = plt.figure()
        # Create the Axes first and configure it directly. Calling plt.grid()
        # or plt.xlim() before fig.add_subplot() configures an implicit Axes,
        # which newer matplotlib then hides under a second, unconfigured one.
        ax = fig.add_subplot(1, 1, 1)
        ax.grid()
        if self.ylabel:
            ax.set_ylabel(self.ylabel)
        ax.set_xlim((0, 23))

        # NOTE(review): labelled hour ticks stop at 14 although the x range is
        # 0-23; np.arange(0, 23, 1) was the commented-out alternative.
        ax.set_xticks(np.arange(0, 15, 1))

        ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        # Automatic minor locators (these would replace any fixed minor ticks).
        ax.yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n_ytic_min))
        ax.xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n_xtic_min))

        ax.set_title(self.title)
        ax.plot(range(24), list(self.yvalues), '.-')

    def scatter(self):
        """Scatter yvalues against xvalues, x-limits clamped to the data range."""
        plt.figure()
        plt.grid()
        plt.title(self.title)

        # Materialise once so min/max/plot all see the same values.
        xs = list(self.xvalues)
        plt.xlim((min(xs), max(xs)))
        plt.scatter(xs, list(self.yvalues))

        self._save_png()
        plt.show()


class MultiPlot(Plot):
    """Figures with several series or labelled points, mostly saved as EPS.

    Methods here index the stored data NumPy-style: ``plot`` uses
    ``yvalues.shape[0]`` / ``yvalues[i, :]`` and the ``*_prob_loss`` methods
    use ``xvalues[:, 0]`` / ``yvalues[:, 0]``, so those inputs are expected
    to be 2-D arrays there.
    """

    @staticmethod
    def _new_axes():
        """Open a new figure with one gridded Axes and return ``(fig, ax)``.

        The Axes is created before any pyplot state call. Newer matplotlib no
        longer reuses an implicit Axes in ``add_subplot``, so this ensures the
        grid, labels and limits land on the Axes that is actually drawn on.
        """
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.grid()
        return fig, ax

    @staticmethod
    def _format_ticks(ax, y_format='%.2f'):
        """Apply the y tick-label format and automatic minor tick locators."""
        ax.yaxis.set_major_formatter(FormatStrFormatter(y_format))
        ax.yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n_ytic_min))
        ax.xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n_xtic_min))

    def _save_eps(self, fig):
        """Write *fig* to ``plots/<title>.eps``; spaces in the path become ``_``."""
        fig.savefig(f"plots/{self.title}.eps".replace(' ', '_'), format='eps',
                    dpi=1000, bbox_inches = 'tight')

    @staticmethod
    def _xtick_step(span):
        """Return the stride (in unit-spaced ticks) between labelled x ticks.

        Picks the smallest "nice" integer stride that leaves at most 10
        intervals over *span*. The stride is used as a slice step, so it must
        be a positive integer; fractional strides would raise TypeError.
        """
        for step in (1, 2, 5, 10, 20, 50, 100, 200, 500):
            if span / step <= 10:
                return step
        # Very wide ranges: fall back to roughly 10 intervals.
        return math.ceil(span / 10)

    # override
    def plot(self, legend_labels):
        """Draw one line per row of ``yvalues`` against ``xvalues``; save as EPS.

        Args:
            legend_labels: legend text for each row of ``yvalues``.

        The EPS is written before the title is set, so the saved file has no
        title while the on-screen figure does.
        """
        import itertools
        import seaborn as sns

        fig, ax = self._new_axes()
        ax.set_xlabel(self.xlabel)
        ax.set_ylabel(self.ylabel)

        x_min = min(self.xvalues)
        if not isinstance(x_min, str):
            # Numeric x: label the integer positions inside the data range
            # (both ends inclusive), thinned to at most ~10 intervals.
            first = math.ceil(x_min)
            last = math.floor(max(self.xvalues))
            ticks = np.arange(first, last + 1)
            if ticks.size:
                ax.set_xticks(ticks[::self._xtick_step(last - first)])
            ax.xaxis.set_tick_params(rotation=45)

        self._format_ticks(ax)

        legend_title = '$T_{MAX}$  [min]'  # alternative: '$N_S$'
        uom = ''  # suffix appended to every legend label, e.g. ' min'

        for i in range(self.yvalues.shape[0]):
            ax.plot(self.xvalues, self.yvalues[i, :], markersize=8,
                    linewidth=2, label=str(legend_labels[i]) + uom)

        # Give each line its own marker and colour, cycling through both lists.
        for line, marker, colour in zip(ax.lines,
                                        itertools.cycle('ov^<>8sp*hHDdPX'),
                                        itertools.cycle(sns.color_palette())):
            line.set_marker(marker)
            line.set_color(colour)

        ax.legend(prop={"size": 10}, title=legend_title, ncol=3,
                  columnspacing=0.05, labelspacing=0.1, handlelength=1.3,
                  borderpad=0.2)

        self._save_eps(fig)
        ax.set_title(self.title)
        plt.show()

    def single_plot(self):
        """Plot yvalues vs xvalues with '.-'; annotate point i with labels[i]."""
        _, ax = self._new_axes()
        ax.set_title(self.title)
        ax.set_xlabel(self.xlabel)
        ax.set_ylabel(self.ylabel)
        self._format_ticks(ax)

        for i, txt in enumerate(self.labels):
            ax.annotate(txt, (self.xvalues[i], self.yvalues[i]), fontsize=18)

        ax.plot(self.xvalues, self.yvalues, ".-")
        plt.show()

    def _plot_prob_loss(self, label, ylabel, y_format, columnspacing):
        """Shared body of the ``*_prob_loss`` plots: one '+' marker per row.

        x is the first column of ``xvalues`` (shown as "Loss Probability"),
        y is the first column of ``yvalues``, and ``label[j]`` is the legend
        entry for point j. The EPS is saved before the title is set.
        """
        x = self.xvalues[:, 0]
        y = self.yvalues[:, 0]

        fig, ax = self._new_axes()
        ax.set_xlabel("Loss Probability")
        ax.set_ylabel(ylabel)
        self._format_ticks(ax, y_format)

        # One plot call per point so that each gets its own legend entry.
        # (To label points in place instead, annotate (x[j], y[j]) with
        # str(label[j]).)
        for j in range(self.yvalues.shape[0]):
            ax.plot(x[j], y[j], marker="+", markersize=16, mew=2,
                    label=str(label[j]))

        ax.legend(prop={"size": 10}, title='$T_{max}$ [min]', ncol=3,
                  columnspacing=columnspacing, labelspacing=0.1,
                  handlelength=0.8, borderpad=0.2)

        self._save_eps(fig)
        ax.set_title(self.title)
        plt.show()

    def plot_cost_prob_loss(self, label):
        """Plot daily cost [€] against loss probability, one legend entry per point."""
        self._plot_prob_loss(label, "Daily Cost [€]", '%.2f', columnspacing=0.1)

    def plot_consumption_prob_loss(self, label):
        """Plot E^G [Wh] against loss probability, one legend entry per point."""
        self._plot_prob_loss(label, "$E^G$ [Wh]", '%.0f', columnspacing=0.2)

