import torch
import torch.optim as optim
import torch.nn as nn
import torch.functional as F
CUDA = torch.cuda.is_available()
if CUDA:
    import cupy as cp
else:
    import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path

from torch.utils.data import DataLoader, TensorDataset

# My files imports
import single


def singleTraining(train_dataset, test_dataset, batch_size, model, loss_fn, optimizer, n_epochs, id, seed=None):
    """Build data loaders, wrap the model in ``single.SingleTest`` and train it.

    Args:
        train_dataset: Dataset iterated for training (loaded with ``shuffle=False``).
        test_dataset: Dataset handed to ``SingleTest`` as the validation set.
        batch_size: Batch size used for both the training and validation loaders.
        model: Model forwarded to ``single.SingleTest``.
        loss_fn: Loss function forwarded to ``single.SingleTest``.
        optimizer: Optimizer forwarded to ``single.SingleTest``.
        n_epochs: Number of epochs forwarded to ``SingleTest.train``.
        id: Identifier forwarded to ``SingleTest.setModelID``. The name shadows
            the builtin but is kept so existing keyword callers keep working.
        seed: Optional seed forwarded to ``SingleTest.train``. When ``None``
            (the default), the previous behavior is kept. With CUDA available,
            a random seed in ``[0, 10000)`` is drawn with CuPy. Otherwise the
            fixed seed 9365 is used.

    Returns:
        The ``single.SingleTest`` instance after ``train`` has been called.
    """
    # NOTE(review): the training loader does not shuffle; confirm this is
    # intentional (e.g. for reproducibility) rather than an oversight.
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    simObject = single.SingleTest(model=model, loss_fn=loss_fn, optimizer=optimizer)
    print(f'Model selected: {simObject.model}\nOptimizer: {simObject.optimizer}\nLoss function: {simObject.loss_fn}\nCuda: {CUDA}')
    simObject.set_loaders(train_loader=train_loader, val_loader=test_loader)
    simObject.setModelID(id)

    if seed is None:
        if CUDA:
            # Random seed drawn on the GPU. ``.get()`` copies the 0-d CuPy
            # array to host, and ``int`` turns it into a plain Python integer.
            seed = int(cp.random.randint(10000).get())
        else:
            # Fixed seed so CPU runs are reproducible.
            seed = 9365

    simObject.train(n_epochs=n_epochs, seed=seed)

    return simObject