import json
import torch 
import torch.onnx
from torch import nn
from sklearn.model_selection import train_test_split
import Dataset as pDataset
from torch import optim
# import status_now_classes
import pandas as pd
# import utilities as util

filename = "PlantNN"
# Location of the trained-model checkpoint, relative to the working directory.
filepath = f"../results/{filename}.pth"

def createModel(structure: list, inSize: int, outSize: int = 1):
    """Build a fully connected ReLU network as an ``nn.Sequential``.

    Args:
        structure: Width of each hidden layer, in order. May be empty, in
            which case the model is a single linear map from input to output.
        inSize: Number of input columns fed to the first layer.
        outSize: Width of the final ``Output`` layer (default 1).

    Returns:
        ``nn.Sequential`` whose children are named ``InputLayer``,
        ``Activation0``, ``Layer1``, ``Activation1``, ..., ``Output``.
        These names become the ``state_dict`` keys, so they must stay
        unchanged for previously saved checkpoints to load.
    """
    modelCreated = nn.Sequential()
    # Width feeding the next Linear layer; starts at the input width and is
    # updated after each hidden layer. Tracking it explicitly (instead of
    # reading the loop index after the loop) keeps an empty `structure` valid.
    prevWidth = inSize
    for i, width in enumerate(structure):
        layerName = 'InputLayer' if i == 0 else 'Layer' + str(i)
        modelCreated.add_module(layerName, nn.Linear(prevWidth, width))
        modelCreated.add_module('Activation' + str(i), nn.ReLU())
        prevWidth = width
    modelCreated.add_module('Output', nn.Linear(prevWidth, outSize))
    return modelCreated

# Importing the model settings from the JSON config. The file is read inside a
# `with` block so the handle is closed as soon as it has been parsed; JSON is
# read as UTF-8 regardless of the platform's default encoding.
SETTINGS_FILE       = "train_config.json"
with open(SETTINGS_FILE, encoding="utf-8") as settings_file:
    settings        = json.load(settings_file)
data_directory      = settings["Data_directory"]
features            = settings["features"]
labels              = settings["labels"]
train_plants_ids    = settings["train_ids"]
test_plants_ids     = settings["test_ids"]
training_start_date = settings["training_start_date"]
training_end_date   = settings["training_end_date"]
features_to_shift   = settings["features_to_shift"]
test_start_date     = settings["test_start_date"]
test_end_date       = settings["test_end_date"]
simulation_type     = settings["SimulationType"]

# Importing the entire model checkpoint.
# Input width = one column per entry of `features` plus, for each entry of
# `features_to_shift`, the number of extra columns configured for it.
inputLayerSize      = len(features) + sum(features_to_shift.values())
model               = createModel(settings["NN_Structure"], inputLayerSize)
optimizer           = optim.SGD(model.parameters(), lr=settings["learning_rate"])

# map_location="cpu" lets a checkpoint that was saved from GPU tensors be
# loaded on a machine without CUDA; the model itself is built on CPU above.
checkpoint = torch.load(filepath, map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# Switch to inference mode before tracing/exporting.
model.eval()

# Creating a dummy input used to trace the network: one row of
# inputLayerSize values.
batch_size = 1
num_features = inputLayerSize
x = torch.randn(batch_size, num_features, requires_grad=True)

# The network takes a single 2-D input tensor (every input column packed side
# by side) and returns a single tensor from its 'Output' layer, so the ONNX
# graph has exactly one input and one output. Naming them after individual
# entries of `features` / `labels` does not line up with the graph: surplus
# names are rejected or attached to the wrong graph values, and dynamic_axes
# keys that are not real graph names are ignored (the batch axis would stay
# fixed). Use one stable name for each instead.
input_names = ["input"]
output_names = ["output"]
# Mark dimension 0 of both tensors as dynamic so any batch size is accepted.
dynamic_axes = {name: {0: "batch_size"} for name in input_names + output_names}

# Sanity-check a forward pass before exporting; no autograd graph is needed.
with torch.no_grad():
    output = model(x)
print(model)

torch.onnx.export(
    model,
    x,
    '../results/' + filename + '.onnx',
    export_params=True,
    opset_version=10,
    do_constant_folding=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)


