############################## PACCHETTI, LIBRERIE E DEVICE ##############################
# Pytorch
import torch
from torch import nn

# Tensorly - Torch
from tensorly.decomposition import parafac
import tltorch

# Math
from math import log, sqrt, pi

# Pick the compute device: the first CUDA GPU when one is available, the CPU otherwise
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

#########################################################################################
#########################################################################################

############################### FEATURES MAPPING - PARAFAC ##############################
# Classe per il features mapping
class PARAFAC_Convolution:
    """Patch-wise PARAFAC feature mapping over a stack of 2-D samples.

    A square window of side ``kernel_size`` slides with step ``stride`` over
    axes 1 and 2 of a 3-D tensor ``X`` (axis 0 indexes the samples).

    * ``fit`` runs ``parafac`` on every window and keeps element ``[1][1:]``
      of each result (in tensorly: the factor matrices of the two window
      axes).
    * ``transform`` / ``flat_transform`` compute, for every sample and every
      window, the ``R`` coefficients that best reconstruct the window from
      the stored factors, using ``SteepestGradientDescent``.

    Only windows that lie entirely inside the sample are used.
    """

    # Class initialisation
    def __init__(self, kernel_size = 4, stride = 1, R = 1,
                 n_iter_max = 100, tol = 1e-6, linesearch = False, device = 'cpu', random_state = 300890):
        """Store the sliding-window and PARAFAC settings.

        Args:
            kernel_size: side of the square window.
            stride: step between two consecutive windows.
            R: rank passed to ``parafac``; also the number of coefficients
                produced per window by the transforms.
            n_iter_max, tol, linesearch: forwarded unchanged to ``parafac``.
            device: device of the random generator and of every tensor
                allocated by the transforms.
            random_state: seed of the generator and of ``parafac``.
        """
        # Seeded generator used to draw the starting point of each descent
        self.prng = torch.Generator(device = device)
        self.prng.manual_seed(random_state)
        self.device = device
        # Per-window factor matrices; populated by ``fit``
        self.kernels = None
        self.convolution_params = (kernel_size, stride)
        self.parafac_params = {'rank': R, 'n_iter_max': n_iter_max,
                               'tol': tol, 'linesearch': linesearch, 'random_state': random_state}

    # Number of windows along each spatial axis
    def _grid_shape(self, n, m):
        """Return ``(n_kernels, m_kernels)`` for samples of size ``n x m``.

        Uses ``floor((size - kernel_size) / stride) + 1`` so that every
        counted window fits completely inside the sample.

        Raises:
            ValueError: if the sample is smaller than the window.
        """
        kernel_size, stride = self.convolution_params
        if n < kernel_size or m < kernel_size:
            raise ValueError(
                f"samples of size {n}x{m} are smaller than kernel_size={kernel_size}")
        return (n - kernel_size) // stride + 1, (m - kernel_size) // stride + 1

    # Enumeration of the windows
    def _windows(self, n, m):
        """Yield ``(idx, i, j, rows, cols)`` for every window, row-major.

        ``idx`` is the flat window index, ``(i, j)`` its grid position and
        ``rows`` / ``cols`` the slices selecting it on axes 1 and 2.
        """
        kernel_size, stride = self.convolution_params
        n_kernels, m_kernels = self._grid_shape(n, m)
        idx = 0
        for i in range(n_kernels):
            rows = slice(i * stride, i * stride + kernel_size)
            for j in range(m_kernels):
                cols = slice(j * stride, j * stride + kernel_size)
                yield idx, i, j, rows, cols
                idx += 1

    # Fit method
    def fit(self, X):
        """Decompose every window of ``X`` and store its factors.

        Args:
            X: 3-D tensor; axis 0 indexes samples, axes 1 and 2 are scanned
                by the sliding window.

        Returns:
            ``self``, so that calls can be chained.
        """
        n, m = X.shape[1:]
        self.kernels = [parafac(tensor = X[:, rows, cols], **self.parafac_params)[1][1:]
                        for _, _, _, rows, cols in self._windows(n, m)]
        return self

    # Checks shared by the two transforms
    def _check_transform_input(self, X):
        """Validate ``X`` against the fitted state.

        Returns:
            ``(n_samples, rank, n_kernels, m_kernels)``.

        Raises:
            RuntimeError: if ``fit`` has not been called yet.
            ValueError: if ``X`` yields a different number of windows than
                the data used in ``fit`` (the stored kernels would be paired
                with the wrong windows).
        """
        if self.kernels is None:
            raise RuntimeError("PARAFAC_Convolution must be fitted before transforming data")
        n_samples, n, m = X.shape
        n_kernels, m_kernels = self._grid_shape(n, m)
        if n_kernels * m_kernels != len(self.kernels):
            raise ValueError(
                f"X yields {n_kernels * m_kernels} windows but {len(self.kernels)} kernels were fitted")
        return n_samples, self.parafac_params['rank'], n_kernels, m_kernels

    # Per-window coefficients shared by the two transforms
    def _project(self, X, n_samples, rank, n_iter_max, tol):
        """Yield ``(idx, i, j, coeffs)`` with the coefficients of each window.

        ``coeffs`` is the ``(n_samples, rank)`` tensor returned by
        ``SteepestGradientDescent`` for window ``idx``.
        """
        n, m = X.shape[1:]
        for idx, i, j, rows, cols in self._windows(n, m):

            # Unfold the region to approximate (one row per sample)
            x_unfold = unfold_X(X = X[:, rows, cols], axis = 0, device = self.device)

            # Unfold the factors fitted for this window
            T = unfold_tensors(self.kernels[idx][0], self.kernels[idx][1], device = self.device)

            # Random starting point drawn within the value range of the region
            x_hat = torch.empty(n_samples, rank,
                                device = self.device).uniform_(x_unfold.min(), x_unfold.max(),
                                                               generator = self.prng)

            yield idx, i, j, SteepestGradientDescent(P = x_hat, T = T, X = x_unfold,
                                                     gd_iter_max = n_iter_max, gd_tol = tol)

    # Transform method
    def transform(self, X, n_iter_max = 50, tol = 1e-8):
        """Map ``X`` to a 4-D tensor of per-window coefficients.

        Args:
            X: 3-D tensor ``(n_samples, n, m)`` with the same window grid as
                the data passed to ``fit``.
            n_iter_max: maximum number of gradient-descent iterations.
            tol: gradient-norm tolerance of the descent.

        Returns:
            Tensor of shape ``(n_samples, rank, n_kernels, m_kernels)``
            allocated on ``self.device``.
        """
        n_samples, rank, n_kernels, m_kernels = self._check_transform_input(X)
        X_hat = torch.empty(n_samples, rank, n_kernels, m_kernels, device = self.device)
        for _, i, j, coeffs in self._project(X, n_samples, rank, n_iter_max, tol):
            X_hat[:, :, i, j] = coeffs
        return X_hat

    # Flat transform method
    def flat_transform(self, X, n_iter_max = 50, tol = 1e-8):
        """Map ``X`` to a 2-D tensor of per-window coefficients.

        Same computation as ``transform``, but the ``rank`` coefficients of
        window ``idx`` are stored in columns
        ``idx * rank : (idx + 1) * rank``.

        Returns:
            Tensor of shape ``(n_samples, rank * n_kernels * m_kernels)``
            allocated on ``self.device``.
        """
        n_samples, rank, n_kernels, m_kernels = self._check_transform_input(X)
        X_hat = torch.empty(n_samples, rank * n_kernels * m_kernels, device = self.device)
        for idx, _, _, coeffs in self._project(X, n_samples, rank, n_iter_max, tol):
            X_hat[:, idx * rank : (idx + 1) * rank] = coeffs
        return X_hat



# Steepest Gradient Descent
def SteepestGradientDescent(P, T, X, gd_iter_max, gd_tol):
    """Fit ``P`` so that ``P @ T.T`` approximates ``X`` by steepest descent.

    Each iteration moves ``P`` along the negative gradient of
    ``0.5 * ||P @ T.T - X||^2`` using the step length that minimises the
    objective along that direction (exact line search).

    Args:
        P: starting coefficients; updated IN PLACE. Its shape must allow
            ``P @ T.T`` to be subtracted from ``X``.
        T: matrix whose transpose maps coefficients to reconstructions.
        X: target matrix.
        gd_iter_max: maximum number of iterations.
        gd_tol: the loop stops after an update whose gradient norm was
            below this value.

    Returns:
        The same tensor object ``P``, after the updates.
    """
    T_trnsp = T.T

    for _ in range(gd_iter_max):

        # Residual of the current reconstruction and gradient w.r.t. P
        X_diff = torch.matmul(P, T_trnsp) - X
        grad_P = torch.matmul(X_diff, T)

        # Optimal step length: <Z, X_diff> / <Z, Z> with Z = grad_P @ T.T
        Z = torch.matmul(grad_P, T_trnsp)
        curvature = torch.square(Z).sum().item()

        # A zero denominator means the step cannot change the reconstruction
        # (e.g. an all-zero region or exact convergence): stop instead of
        # dividing 0 by 0.
        if curvature == 0.0:
            break

        alpha = torch.mul(Z, X_diff).sum().item() / curvature

        # Update P in place
        P -= alpha * grad_P

        # Stop criterion: norm of the gradient used for this update
        if torch.norm(grad_P) < gd_tol:
            break

    return P



# Funzione per l'unfolding di un tensore 3D
def unfold_X(X, axis, device):
    """Unfold a 3-D tensor into a matrix along the chosen axis.

    The chosen axis is first moved to position 0 (``axis == 1`` swaps axes
    0 and 1, any other non-zero value swaps axes 0 and 2). With the
    resulting shape ``(I, J, T)`` the output has shape ``(I, J * T)`` and
    satisfies ``out[i, t * J + j] == moved[i, j, t]``.

    Args:
        X: 3-D tensor.
        axis: axis that becomes the row index of the result.
        device: device on which the result is allocated.

    Returns:
        New tensor created with ``torch.empty`` (hence in the default
        dtype) holding the unfolded values.
    """
    # Bring the requested axis to the front
    if axis == 0:
        moved = X
    elif axis == 1:
        moved = X.transpose(0, 1)
    else:
        moved = X.transpose(0, 2)

    n_rows, n_inner, n_slices = moved.shape

    # Allocate the result, then fill all slices at once: viewing the output
    # as (rows, slice, inner) makes slice t occupy columns t*J .. (t+1)*J
    unfolded = torch.empty(n_rows, n_inner * n_slices, device = device)
    unfolded.view(n_rows, n_slices, n_inner).copy_(moved.permute(0, 2, 1))

    return unfolded



# Builds the unfolded factor matrix from T1 and T2: row (j * t1 + i) holds T1[i, :] * T2[j, :]
# (column-wise Khatri-Rao product of T2 and T1)
def unfold_tensors(T1, T2, device):
    """Stack the row-scaled copies of ``T1`` defined by the rows of ``T2``.

    With ``t1 = T1.shape[0]`` and ``t2 = T2.shape[0]`` the result has shape
    ``(t1 * t2, T1.shape[1])`` and its rows ``j * t1 : (j + 1) * t1`` equal
    ``T1 * T2[j]`` (element-wise, broadcast over the rows of ``T1``).

    Args:
        T1: 2-D factor matrix.
        T2: 2-D factor matrix whose rows scale ``T1``.
        device: device on which the result is allocated.

    Returns:
        New tensor created with ``torch.empty`` (default dtype).
    """
    n_rows_1, n_rows_2 = T1.shape[0], T2.shape[0]
    n_cols = T1.shape[1]

    stacked = torch.empty(n_rows_1 * n_rows_2, n_cols, device = device)

    # One (t1, n_cols) block per row of T2, computed in a single broadcast
    # product instead of a Python loop
    stacked.view(n_rows_2, n_rows_1, n_cols).copy_(T2.unsqueeze(1) * T1.unsqueeze(0))

    return stacked

#########################################################################################
#########################################################################################

######################################## LENET5 #########################################
class LeNet5(nn.Module):
    """LeNet-5 style convolutional classifier for single-channel images.

    The classifier expects 400 input features, i.e. the 16 x 5 x 5 output
    the feature extractor produces for 28 x 28 inputs.
    """

    # Class initialisation
    def __init__(self, n_classes):
        """Build the network.

        Args:
            n_classes: number of output logits.
        """
        super().__init__()

        # Convolutional layers
        conv_stack = [
            nn.Conv2d(in_channels = 1, out_channels = 6, kernel_size = 5, stride = 1, padding = 2),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size = 2),
            nn.BatchNorm2d(6),
            nn.Conv2d(in_channels = 6, out_channels = 16, kernel_size = 5, stride = 1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size = 2),
            nn.BatchNorm2d(16),
        ]
        self.feature_extractor = nn.Sequential(*conv_stack)

        # Multi-layer perceptron
        dense_stack = [
            nn.Linear(in_features = 400, out_features = 120),
            nn.Tanh(),
            nn.Linear(in_features = 120, out_features = 84),
            nn.Tanh(),
            nn.Linear(in_features = 84, out_features = n_classes),
        ]
        self.classifier = nn.Sequential(*dense_stack)

    # Forward propagation
    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        features = self.feature_extractor(x)
        return self.classifier(features.flatten(start_dim = 1))
    
#########################################################################################
#########################################################################################

####################################### PFNET-RBF #######################################
# Funzione di attivazione RBF
class RBFActivation(nn.Module):
    """Gaussian radial-basis activation: ``exp(-x ** 2 / 2)`` element-wise."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Apply the activation to every element of ``input``."""
        squared = torch.square(input)
        return torch.exp(-0.5 * squared)
    
# PFNET - RBF
class MLP(torch.nn.Module):
    """RBF perceptron applied to a 4-D feature map.

    The input is batch-normalised per channel, flattened, passed through a
    square linear layer followed by batch normalisation and the RBF
    activation, and finally projected to ``output_dim`` values.
    """

    # Class initialisation
    def __init__(self, R, shape, output_dim):
        """Build the network.

        Args:
            R: number of channels of the input feature map.
            shape: side of the (square) feature map; the flattened input
                has ``R * shape ** 2`` features.
            output_dim: number of outputs.
        """
        super().__init__()

        # Normalisation of the feature map
        input_dim = R * shape ** 2
        self.std = nn.Sequential(
            nn.BatchNorm2d(R),
        )

        # Multi-layer perceptron
        self.layers = nn.Sequential(

            # Kernel trick
            nn.Linear(input_dim, input_dim),
            # The normalised width must follow the layer above; it used to
            # be fixed to 162, which only worked for R * shape ** 2 == 162
            nn.BatchNorm1d(input_dim),
            RBFActivation(),

            # Perceptron
            nn.Linear(input_dim, output_dim)
        )

    # Forward propagation
    def forward(self, x):
        """Return the network output for a batch ``x``."""
        y = self.std(x)
        # Flatten everything but the batch axis
        y = self.layers(torch.flatten(y, 1))
        return y

#########################################################################################
#########################################################################################

###################################### TRAIN-TEST #######################################
# Train
def train(net, data_loader, loss_function, optimizer):
    """Run one training pass over ``data_loader``.

    Args:
        net: model to train; switched to training mode.
        data_loader: iterable of ``(X, y)`` batches.
        loss_function: callable ``loss_function(y_hat, y)`` returning a
            scalar tensor.
        optimizer: optimizer bound to the parameters of ``net``.

    Returns:
        ``(cumulative_loss / samples, accuracy_percent)`` where
        ``cumulative_loss`` is the sum of the per-batch loss values and
        ``samples`` the number of samples seen.

    Raises:
        ValueError: if ``data_loader`` yields no samples.

    NOTE(review): the per-batch loss is summed as-is and divided by the
    sample count, which is a per-sample average only if ``loss_function``
    sums over the batch -- confirm against the losses used by callers.
    """
    samples = 0
    cumulative_loss = 0
    cumulative_accuracy = 0

    # Training mode
    net.train()

    # Batches go where the model lives; the module-level ``device`` is only
    # used for models without parameters
    first_param = next(net.parameters(), None)
    target_device = device if first_param is None else first_param.device

    # Drop gradients left over from earlier backward passes so they cannot
    # leak into the first update
    optimizer.zero_grad()

    for X, y in data_loader:

        X = X.to(target_device)
        y = y.to(target_device)

        # Prediction and loss
        y_hat = net(X)
        loss = loss_function(y_hat, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Running statistics
        samples += X.size(0)
        cumulative_loss += loss.item()
        _, y_pred = y_hat.max(1)
        cumulative_accuracy += y_pred.eq(y).sum().item()

    if samples == 0:
        raise ValueError("data_loader yielded no samples")

    return cumulative_loss / samples, (cumulative_accuracy / samples) * 100



# Test 
def test(net, data_loader, loss_function):
    """Evaluate ``net`` on ``data_loader`` without computing gradients.

    Args:
        net: model to evaluate; switched to evaluation mode.
        data_loader: iterable of ``(X, y)`` batches.
        loss_function: callable ``loss_function(y_hat, y)`` returning a
            scalar tensor.

    Returns:
        ``(cumulative_loss / samples, accuracy_percent)`` where
        ``cumulative_loss`` is the sum of the per-batch loss values and
        ``samples`` the number of samples seen.

    Raises:
        ValueError: if ``data_loader`` yields no samples.

    NOTE(review): the per-batch loss is summed as-is and divided by the
    sample count, which is a per-sample average only if ``loss_function``
    sums over the batch -- confirm against the losses used by callers.
    """
    samples = 0
    cumulative_loss = 0
    cumulative_accuracy = 0

    # Evaluation mode
    net.eval()

    # Batches go where the model lives; the module-level ``device`` is only
    # used for models without parameters
    first_param = next(net.parameters(), None)
    target_device = device if first_param is None else first_param.device

    # No gradients are needed for evaluation
    with torch.no_grad():

        for X, y in data_loader:

            X = X.to(target_device)
            y = y.to(target_device)

            # Prediction and loss
            y_hat = net(X)
            loss = loss_function(y_hat, y)

            # Running statistics
            samples += X.size(0)
            cumulative_loss += loss.item()
            _, y_pred = y_hat.max(1)
            cumulative_accuracy += y_pred.eq(y).sum().item()

    if samples == 0:
        raise ValueError("data_loader yielded no samples")

    return cumulative_loss / samples, (cumulative_accuracy / samples) * 100


# Despite its name this is an evaluation helper, not a unit test: keep pytest
# from collecting it when it is imported into a test module
test.__test__ = False