import json
import math
import Dataset as pDataset
from torch.utils.data import TensorDataset
import torch 
from torch import nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
import single_simulation as ST
import matplotlib.pyplot as plt
from itertools import product
import pathlib
import pandas as pd
from datetime import datetime

SETTINGS_FILE       = "train_config.json"
# Read the run configuration. The context manager closes the file immediately
# instead of leaving the handle open until garbage collection.
# NOTE(review): SETTINGS_FILE is resolved against the current working directory,
# whereas results are written relative to this script's location (simulation_path).
with open(SETTINGS_FILE, encoding="utf-8") as settings_stream:
    settings        = json.load(settings_stream)
data_directory      = settings["Data_directory"]
features            = settings["features"]
labels              = settings["labels"]
train_plants_ids    = settings["train_ids"]
test_plants_ids     = settings["test_ids"]
training_start_date = settings["training_start_date"]
training_end_date   = settings["training_end_date"]
features_to_shift   = settings["features_to_shift"]
test_start_date     = settings["test_start_date"]
test_end_date       = settings["test_end_date"]
simulation_type     = settings["SimulationType"]

# Model-selection metric ("Accuracy", "F1Score" or "MCC") and the value a model's
# metric must exceed to be kept (see filterAndSaveModels).
metric              = settings["Metric"]
metricThreshold     = settings["MetricThreshold"]

#Hyperparameters
learning_rate       = settings["learning_rate"]
epochs              = settings["epochs"]
batch_size          = settings["batch_size"]
dtype = torch.float32

# Each run writes into its own timestamped folder:
# <parent of this script's directory>/results/<timestamp>_<SimulationType>
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
simulation_path = str(pathlib.Path(__file__).parent.parent)+"/results/"+current_time+"_"+simulation_type

def createDirectories():
    """Create the results folder of this run and its ``models`` sub-folder.

    Both directories are created with their parents; existing ones are left untouched.
    """
    results_root = pathlib.Path(simulation_path)
    for directory in (results_root, results_root / "models"):
        directory.mkdir(parents=True, exist_ok=True)


def createModel(structure : list, inSize : int):
    """Build a fully connected network with a single output unit.

    Each entry of ``structure`` is the width of one hidden ``nn.Linear`` layer,
    followed by an ``nn.ReLU``. A final ``nn.Linear`` with one output and no
    activation is appended, so the model emits raw scores; this script pairs it
    with ``nn.BCEWithLogitsLoss``.

    Args:
        structure: hidden-layer widths in order. Any sequence works, e.g. the tuples
            produced by ``itertools.product``. An empty sequence yields a single
            linear layer from the inputs straight to the output.
        inSize: number of input features.

    Returns:
        The ``nn.Sequential`` model. Sub-module names ('InputLayer', 'Layer<i>',
        'Activation<i>', 'Output') are kept stable because they determine the
        parameter names in the model's ``state_dict()``.
    """
    modelCreated = nn.Sequential()
    previous_width = inSize
    for i, width in enumerate(structure):
        layer_name = 'InputLayer' if i == 0 else 'Layer' + str(i)
        modelCreated.add_module(layer_name, nn.Linear(previous_width, width))
        modelCreated.add_module('Activation' + str(i), nn.ReLU())
        previous_width = width
    # Track the last width explicitly. The loop index is unbound when `structure` is
    # empty, which previously raised UnboundLocalError here.
    modelCreated.add_module('Output', nn.Linear(previous_width, 1))
    return modelCreated

def save_outputs_plots():
    """Save one PDF per test plant with its input features, label and prediction.

    Reads the module-level ``test_dataset_original`` (which must already hold the
    'Status' and 'Predicted' columns), ``features`` and ``simulation_path``. Writes
    ``<simulation_path>/plots/outputs/Plant_<id>.pdf`` for every distinct plant id.
    Features are drawn against the left y-axis; 'Status' and 'Predicted' use a
    twin right y-axis so their scale is independent of the features.
    """
    plots_directory = pathlib.Path(simulation_path+"/plots/outputs")
    plots_directory.mkdir(parents=True, exist_ok=True)

    for plant_id in sorted(test_dataset_original['id'].unique()):
        pdf_file_path = plots_directory / f'Plant_{plant_id}.pdf'
        data_filtered = test_dataset_original.loc[test_dataset_original['id'] == plant_id]

        fig, ax1 = plt.subplots(figsize=(16,9))
        try:
            for feature in features:
                ax1.plot(data_filtered['Date'], data_filtered[feature], label=feature)

            # Label and prediction are drawn on the twin axis. The twin axis restarts
            # the default colour cycle, so dashed lines keep them distinguishable
            # from the features.
            ax2 = ax1.twinx()
            for column in ['Status', 'Predicted']:
                ax2.plot(data_filtered['Date'], data_filtered[column], label=column, linestyle='--')

            ax1.set_xlabel('Date')
            ax1.set_title(f'Plant {plant_id}')
            # A single legend listing the lines of both axes. It is attached to ax2,
            # which is drawn on top, so ax2's lines cannot hide it.
            handles_left, names_left = ax1.get_legend_handles_labels()
            handles_right, names_right = ax2.get_legend_handles_labels()
            ax2.legend(handles_left + handles_right, names_left + names_right)
            fig.savefig(pdf_file_path)
        finally:
            # Release the figure; otherwise every plant keeps one open for the whole run.
            plt.close(fig)

def filterAndSaveModels(listOfModels : list, filteringMetric : str, threshold : float):
    """Keep the models whose validation metric beats ``threshold``, rank them and export them.

    A model is kept only if its selected metric is strictly greater than
    ``threshold``. The kept models are sorted best-first (ties keep their input
    order). Each one is exported to ONNX at
    ``<simulation_path>/models/Model_<modelID>.onnx`` through its
    ``export_to_onnx`` method.

    Args:
        listOfModels: trained-model results exposing ``modelID``,
            ``final_val_accuracy``, ``final_val_f1Score``, ``final_val_MCC`` and
            ``export_to_onnx``.
        filteringMetric: "Accuracy", "F1Score" or "MCC".
        threshold: exclusive lower bound on the selected metric.

    Returns:
        A new list with the kept models, best first.

    Raises:
        ValueError: if ``filteringMetric`` is not one of the supported names.
            Previously an unknown name silently exported every model unfiltered
            and unranked.
    """
    metric_attributes = {
        "Accuracy": "final_val_accuracy",
        "F1Score": "final_val_f1Score",
        "MCC": "final_val_MCC",
    }
    try:
        attribute = metric_attributes[filteringMetric]
    except KeyError:
        raise ValueError(
            f"Unsupported filtering metric {filteringMetric!r}; "
            f"expected one of {sorted(metric_attributes)}"
        ) from None

    kept_models = sorted(
        (mod for mod in listOfModels if getattr(mod, attribute) > threshold),
        key=lambda mod: getattr(mod, attribute),
        reverse=True,
    )
    for mod in kept_models:
        mod.export_to_onnx(simulation_path+"/models/Model_"+str(mod.modelID)+".onnx",features,labels)
    return kept_models

def simulationRecap(filteringMetric, listOfModels):
    """Write a human-readable recap of the simulation to ``simulation_recap.txt``.

    The file is created in ``simulation_path``. It contains the simulation type,
    the module-level ``dataset_min``/``dataset_max`` values, a ranking table
    (model id and the selected metric) and the architecture of every model.
    Besides its arguments, the function reads the module-level globals
    ``simulation_path``, ``simulation_type``, ``dataset_min`` and ``dataset_max``.

    Args:
        filteringMetric: "Accuracy", "F1Score" or "MCC"; selects the metric column.
            For any other value only the "MODELS RANKING" heading is written.
        listOfModels: models in the order they should be listed, each exposing
            ``modelID``, ``model`` and the ``final_val_*`` metric attributes.
    """
    # metric name -> (column header, attribute holding the value on each model)
    metric_columns = {
        "Accuracy": ("ACCURACY", "final_val_accuracy"),
        "F1Score": ("F1 SCORE", "final_val_f1Score"),
        "MCC": ("MCC", "final_val_MCC"),
    }
    # The context manager guarantees the recap is flushed and the handle closed.
    with open(simulation_path+"/simulation_recap.txt", "w") as recap_file:
        recap_file.write("SIMULATION TYPE:\t\t"+simulation_type+"\n\n")
        recap_file.write("DATASET_MIN:\n"+str(dataset_min)+"\n\nDATASET_MAX:\n"+str(dataset_max)+"\n\n")

        recap_file.write("MODELS RANKING\n")
        if filteringMetric not in metric_columns:
            return
        header, attribute = metric_columns[filteringMetric]

        recap_file.write("MODEL ID"+"\t\t"+header+"\n")
        for mod in listOfModels:
            recap_file.write(str(mod.modelID)+"\t\t\t\t"+str(getattr(mod, attribute))+"\n")
        recap_file.write("_"*100+"\n\n")

        recap_file.write("MODEL CHARACTERISTICS\n")
        for mod in listOfModels:
            recap_file.write("MODEL ID : "+str(mod.modelID)+"\n")
            recap_file.write(str(mod.model)+"\n")

def simulationRecapJSON(filteringMetric, listOfModels):
    """Write machine-readable recap files into ``simulation_path``.

    * ``minmax.json``: a list with one ``{feature: {"max": ..., "min": ...}}`` entry
      per configured feature, taken from ``dataset_max`` / ``dataset_min``.
    * ``models.json``: written only for "NeuronsSweep" runs. It is a list of
      ``{"modelID": ..., <filteringMetric>: ...}`` entries, one per model, in input order.

    Args:
        filteringMetric: "Accuracy", "F1Score" or "MCC"; the metric stored per model.
        listOfModels: models exposing ``modelID`` and the ``final_val_*`` attributes.
    """
    minmax_entries = [
        {feature: {"max": float(dataset_max[feature]), "min": float(dataset_min[feature])}}
        for feature in features
    ]
    with open(simulation_path + "/minmax.json", 'w') as minmax_stream:
        json.dump(minmax_entries, minmax_stream, indent=4)

    if simulation_type != "NeuronsSweep":
        return

    ranking = []
    for mod in listOfModels:
        # All three metrics are read for every model; only the requested one is stored.
        scores = {
            "Accuracy": float(mod.final_val_accuracy),
            "F1Score": float(mod.final_val_f1Score),
            "MCC": float(mod.final_val_MCC),
        }
        ranking.append({'modelID': int(mod.modelID), str(filteringMetric): scores[filteringMetric]})
    with open(simulation_path + "/models.json", 'w') as models_stream:
        json.dump(ranking, models_stream, indent=4)
    
modelsList  = []
modelID     = 0


# NOTE(review): train_plants_ids is passed for two consecutive id arguments. If the
# second one is meant to receive the test plants, it should be test_plants_ids;
# confirm against Dataset.PlantsDatasetCreation. The test_start_date/test_end_date
# settings are not used anywhere in the code shown.
dataset, dataset_min, dataset_max  = pDataset.PlantsDatasetCreation(data_directory,features.copy(),labels,train_plants_ids,train_plants_ids,training_start_date,training_end_date,features_to_shift=features_to_shift)
# A shuffled split (train_test_split) is not used; the same frame is split by row position below.
train_dataset = dataset
test_dataset = dataset

# Positional 80/20 split. Training rows come from the start of the frame and test
# rows from its end. This way the two sets cannot overlap when train and test plant
# ids coincide; taking head() for both made every test row a training row too.
# .copy() detaches both frames from `dataset`, so later column assignments and sorts
# do not act on a view (pandas SettingWithCopyWarning).
n_rows = dataset.shape[0]
train_dataset_original = train_dataset.loc[train_dataset['id'].isin(train_plants_ids)].head(int(0.8*n_rows)).copy()
test_dataset_original = test_dataset.loc[test_dataset['id'].isin(test_plants_ids)].tail(int(0.2*n_rows)).copy()
train_dataset       = train_dataset_original
test_dataset        = test_dataset_original

# Tensors hold only model inputs/targets: Date and id are dropped everywhere, and the
# Status label column is excluded from the features and used alone as the target.
excluded_columns    = ["Date","id","Status"]
train_features      = torch.tensor(train_dataset.loc[:,~train_dataset.columns.isin(excluded_columns)].values, dtype=dtype)
train_labels        = torch.tensor(train_dataset.loc[:,train_dataset.columns.isin(["Status"])].values, dtype=dtype)
test_features       = torch.tensor(test_dataset.loc[:,~test_dataset.columns.isin(excluded_columns)].values, dtype=dtype)
test_labels         = torch.tensor(test_dataset.loc[:,test_dataset.columns.isin(["Status"])].values, dtype=dtype)
train_dataset       = TensorDataset(train_features,train_labels)
test_dataset        = TensorDataset(test_features,test_labels)

inputLayerSize      = test_features.size()[1]
outputLayerSize     = test_labels.size()[1]

if simulation_type == "Single":

    model               = createModel(settings["NN_Structure"], inputLayerSize)
    loss                = nn.BCEWithLogitsLoss()
    optimizer           = optim.SGD(model.parameters(),lr=learning_rate)
    res                 = ST.singleTraining(train_dataset=train_dataset,test_dataset=test_dataset,batch_size=batch_size,model=model,loss_fn=loss,optimizer=optimizer,n_epochs=epochs,id=1)
    modelID +=1
    modelsList.append(res)
    # Plotting training curves
    pathlib.Path.mkdir(pathlib.Path(simulation_path+"/plots"), parents = True, exist_ok=True)
    res.plot_losses().savefig(simulation_path+"/plots/losses.pdf")
    res.plot_accuracy().savefig(simulation_path+"/plots/accuracy.pdf")
    res.plot_MCC().savefig(simulation_path+"/plots/MCC.pdf")
    res.plot_f1_score().savefig(simulation_path+"/plots/f1_score.pdf")
    # Run-specific values are recorded so they end up in the copied settings file.
    settings["Seed"]=int(res.seed)
    settings["TrainWeight"]=float(train_dataset_original["Status"].mean())
    settings["TestWeight"]=float(test_dataset_original["Status"].mean())

    # Score the test set. The network from createModel ends in a bare nn.Linear and
    # is trained with BCEWithLogitsLoss, so its output is a logit. The sigmoid turns
    # it into a probability comparable with the Status column in the plots.
    # NOTE(review): assumes res.model is that same network; confirm in single_simulation.
    with torch.no_grad():
        test_logits = res.model(test_features)
    predicted = torch.sigmoid(test_logits).squeeze(-1).numpy()

    # assign() and sort_values() return new frames, so no view of `dataset` is
    # modified in place. save_outputs_plots() reads the rebound module-level frame.
    test_dataset_original = test_dataset_original.assign(Predicted=predicted).sort_values('Date')
    save_outputs_plots()

if simulation_type == "NeuronsSweep":
    hidden_layers_number    = settings["NSW_Hidden_layers"]
    start_neurons_number    = settings["NSW_Start_Neurons"]
    stop_neurons_number     = settings["NSW_Stop_Neurons"]
    step_neurons            = settings["NSW_Step_Neurons"]
    logarithmic_step        = settings["NSW_Logarithmic_Step"]
    step_type               = settings["NSW_Step_Type"]

    if step_type == "lin":
        # start, start+step, ... up to and including stop.
        neurons_values = list(range(start_neurons_number, stop_neurons_number+1, step_neurons))
    elif step_type == "log":
        # start, start*10**k, start*10**(2k), ... (k = logarithmic_step) up to and
        # including stop, mirroring the inclusive "lin" sweep. Integer powers avoid
        # float-log rounding (math.log(1000, 10) == 2.9999999999999996). The bound is
        # relative to start, so values never overshoot stop (the previous bound
        # ignored start, e.g. start=50, stop=100 produced [50, 500]).
        if start_neurons_number <= 0 or logarithmic_step <= 0:
            raise ValueError("NSW_Start_Neurons and NSW_Logarithmic_Step must be positive for a 'log' sweep")
        neurons_values = []
        exponent = 0
        while start_neurons_number * 10 ** exponent <= stop_neurons_number:
            neurons_values.append(math.floor(start_neurons_number * 10 ** exponent))
            exponent += logarithmic_step
    else:
        raise ValueError(f"Unsupported NSW_Step_Type {step_type!r}; expected 'lin' or 'log'")

    # Every combination of the candidate widths over a fixed number of hidden layers
    # (cartesian product). With 0 hidden layers this is the single empty structure ().
    structures = list(product(neurons_values, repeat=hidden_layers_number))

    for single_struct in structures:
        model               = createModel(single_struct, inputLayerSize)
        loss                = nn.BCEWithLogitsLoss()
        optimizer           = optim.SGD(model.parameters(),lr=learning_rate)
        res = ST.singleTraining(train_dataset=train_dataset,test_dataset=test_dataset,batch_size=batch_size,model=model,loss_fn=loss,optimizer=optimizer,n_epochs=epochs,id=modelID)
        modelID +=1
        modelsList.append(res)

# Models filtering and ranking by performance
createDirectories()
modelsList = filterAndSaveModels(modelsList,metric,metricThreshold)

# Generate simulation recap
simulationRecap(metric, modelsList)
simulationRecapJSON(metric, modelsList)

# Copy the settings into the results directory. These include values added during
# the run, e.g. Seed/TrainWeight/TestWeight in a Single simulation. The context
# manager makes sure the file is flushed and closed.
with open(simulation_path+"/"+current_time+"_"+simulation_type+"_settings_file.json","w") as setting_copy:
    json.dump(settings,setting_copy,indent=2)

# TODO: copy the input plots used for this run into the results directory (not implemented).
