############################## PACCHETTI, LIBRERIE E DEVICE ##############################
# NumPy
import numpy as np

# Tensorly
from tensorly.decomposition import parafac

#########################################################################################
#########################################################################################

###################################### PARAFAC - 3D #####################################
def PARAFAC(X, R = 1, als_iter_max = 50, als_tol = 1e-8, n_max_nochange = 5,
            gd_iter_max = 50, gd_tol = 1e-8, random_state = 300890):
    
    # Inizializzazione prng
    prng = np.random.default_rng(random_state)
    
    # Estrazione della dimensionalità del tensore
    t, n, m = X.shape
    
    # Numero di elementi del tensore
    N = t * n * m
    
    # Bounds per l'inizializzazione
    lb, ub = X.min(), X.max()
    
    # Calcolo degli unfolding di X
    X_u0 = unfold_X(X, axis = 0)
    X_u1 = unfold_X(X, axis = 1)
    X_u2 = unfold_X(X, axis = 2)

    # Inizializzazione della decomposizione
    V0 = prng.uniform(lb, ub, (t, R))
    V1 = prng.uniform(lb, ub, (n, R))
    V2 = prng.uniform(lb, ub, (m, R))
    
    # Inizializzazione delle variabili di controllo 
    als_iter_max = int(als_iter_max)
    gd_iter_max = int(gd_iter_max)
    n_max_nochange = int(n_max_nochange)
    loss_track = np.empty((als_iter_max + 1))
    loss_track[0] = 0
    count_nochange = 0

    # Iterazioni ALS
    for k in range(als_iter_max):

        # Aggiornamento tensori
        Z = unfold_tensors(V1, V2)
        V0 = SteepestGradientDescent(V0, Z, X_u0, gd_iter_max, gd_tol)
        Z = unfold_tensors(V0, V2)
        V1 = SteepestGradientDescent(V1, Z, X_u1, gd_iter_max, gd_tol)
        Z = unfold_tensors(V1, V0)
        V2 = SteepestGradientDescent(V2, Z, X_u2, gd_iter_max, gd_tol)

        # Calcolo della funzione obiettivo
        loss_track[k + 1] = ((X_u2 - (V2.dot(Z.T))) ** 2).sum() / (2 * N)

        # Calcolo della variazione della soluzione
        delta_loss = loss_track[k + 1] - loss_track[k]
        if abs(delta_loss) < als_tol:
            count_nochange += 1
        elif count_nochange > 0:
            count_nochange = 0

        # Criterio di stop
        if count_nochange == n_max_nochange:
            break

    return (V0, V1, V2), loss_track[1: k + 2]



# Steepest Gradient Descent
def SteepestGradientDescent(V, Z, X_unfold, gd_iter_max, gd_tol):

    # Inizializzazione della matrice Z^T 
    Z_trnsp = Z.T

    # Steepest Gradient Descent
    for k in range(gd_iter_max):

        # Calcolo di X_hat
        X_hat = V.dot(Z_trnsp)

        # Calcolo del gradiente dell'MSE rispetto a V
        psi = X_unfold - X_hat
        grad_V = psi.dot(Z) / X_unfold.size

        # Calcolo dello step-length ottimale 
        phi = grad_V.dot(Z_trnsp)
        alpha = - (phi * psi).sum()
        alpha /= (phi ** 2).sum()

        # Update della matrice V
        V -= alpha * grad_V

        # Criterio di stop: check della norma del gradiente
        norm_V = np.linalg.norm(grad_V)
        if norm_V < gd_tol:
            break

    return V



# Funzione per l'unfolding di un tensore 3D
def unfold_X(X, axis):

    # Trasposizione tensore 
    if axis == 0:
        X = X.copy()
    elif axis == 1: 
        X = np.swapaxes(X, 0, 1)
    else: 
        X = np.swapaxes(X, 0, 2)

    # Estrazione della dimensionalità del tensore 
    I, J, T = X.shape

    # X_unfold = np.ndarray (shape: (I, K = J * T))
    K = J * T
    X_unfold = np.empty((I, K))

    # Loop lungo l'asse 2
    for t in range(T):
        # Ad ogni iterazione si memorizza sequenzialmente la t-esima  
        # matrice all'intero del nuovo np.ndarray X_unfold
        X_unfold[:, t * J : (t + 1) * J] = X[:, :, t]

    return X_unfold



# Funzione per l'unfolding di T = T1.dot(T2.T), dati T1, T2
def unfold_tensors(V1, V2):

    # Estrazione delle dimensionalità di V1 e V2 
    v1, v2 = V1.shape[0], V2.shape[0]

    # Z = np.ndarray (shape: (v1 * v2, R))
    # Per coerenza di notazione nel codice, l'asse di lunghezza
    # R viene mantenuto in posizione 1
    Z = np.empty((v1 * v2, V1.shape[1]))

    # Loop lungo l'asse 2 
    for idx, row in enumerate(V2):
        # Ad ogni iterazione si moltiplicano le colonne di V1
        # per i rispettivi scalari delle colonne di V2 e le matrici
        # risultanti vengono allocate in T_unfold
        Z[idx * v1 : (idx + 1) * v1, :] = V1 * row

    return Z

#########################################################################################
#########################################################################################

###################################### ORTHOGONALITY ####################################
def orthogonality(V):
    V = V / np.linalg.norm(V, axis = 0)
    R = V.shape[1]
    N = R ** 2
    return ((V.T.dot(V) - np.eye(R)) ** 2).sum() / N

#########################################################################################
#########################################################################################

########################################## PFA ##########################################
class PFA:
    
    
    
    # Inizializzazione della classe
    def __init__(self, R = 1, n_iter_max = 100, tol = 1e-6, 
                 linesearch = False, random_state = 300890):
        
        # Inizializzazione degli attributi della classe 
        self.prng = np.random.default_rng(random_state)
        self.features_tensors = None
        self.features_tensors_unfold = None
        
        # Parametri per la decomposizione parafac
        self.parafac_params = {'rank': R, 'n_iter_max': n_iter_max,
                              'tol': tol, 'linesearch': linesearch, 'random_state': random_state}
    
    

    # Metodo fit
    def fit(self, X):
        
        # Calcolo della decomposizione di X
        self.features_tensors = parafac(tensor = X, **self.parafac_params)[1][1:]
        V1, V2  = self.features_tensors
        self.features_tensors_unfold = unfold_tensors(V1, V2)
        
        return self
    
    
    
    # Metodo transform
    def transform(self, X, n_iter_max = 50, tol = 1e-8):
        
        # Calcolo della 'proiezione' di X rispetto ai tensori delle features appresi
        X_unfold = unfold_X(X = X, axis = 0)
        X_hat = self.prng.uniform(X.min(), X.max(), (X_unfold.shape[0], self.parafac_params['rank']))
        X_hat = SteepestGradientDescent(V = X_hat, Z = self.features_tensors_unfold, 
                                        X_unfold = X_unfold, gd_iter_max = n_iter_max, gd_tol = tol)
        
        return X_hat
    
#########################################################################################
#########################################################################################

########################################## PFA ##########################################
class PARAFAC_Convolution:
    
    
    
    # Inizializzazione classe
    def __init__(self, kernel_size = 4, stride = 1, R = 1,
                n_iter_max = 100, tol = 1e-6, linesearch = False, random_state = 300890):
        
        # Inizializzazione attributi
        self.prng = np.random.default_rng(random_state)
        self.kernels = None
        self.convolution_params = (kernel_size, stride)
        self.parafac_params = {'rank': R, 'n_iter_max': n_iter_max,
                              'tol': tol, 'linesearch': linesearch, 'random_state': random_state}
        
        
        
    # Metodo fit
    def fit(self, X):
        
        # Calcolo del numero di kernels
        n, m = X.shape[1:]
        kernel_size, stride = self.convolution_params
        rank = self.parafac_params['rank']
        n_kernels = round((n - kernel_size) / stride) + 1
        m_kernels = round((m - kernel_size) / stride) + 1
        
        # Inizializzazione kernels 
        self.kernels =  [None] * (n_kernels * m_kernels)
        
        # Inizializzazione indicizzazione
        idx = 0
        
        # Iterazioni sui kernels
        for i in range(n_kernels):
            for j in range(m_kernels):
                
                # Calcolo degli indici
                i_init, j_init = i * stride, j * stride
                i_stop, j_stop = i_init + kernel_size, j_init + kernel_size
                
                # Decomposizione 
                self.kernels[idx] = parafac(tensor = X[:, i_init : i_stop, j_init : j_stop],
                                            **self.parafac_params)[1][1:]
                
                # Aggiornamento indice
                idx += 1
        
        return self
    
    
    
    # Metodo transform
    def transform(self, X, n_iter_max = 50, tol = 1e-8):
        
        # Estrazione numero di samples
        n_samples = X.shape
        
        # Calcolo del numero di kernels
        n_samples, n, m = X.shape
        kernel_size, stride = self.convolution_params
        rank = self.parafac_params['rank']
        n_kernels = round((n - kernel_size) / stride) + 1
        m_kernels = round((m - kernel_size) / stride) + 1
        
        # Inizializzazione X_hat
        X_hat = np.empty((n_samples, rank, n_kernels, m_kernels))
        
        # Inizializzazione indicizzazione
        idx = 0
        
        # Iterazioni sui kernels
        for i in range(n_kernels):
            for j in range(m_kernels):
                
                # Calcolo degli indici
                i_init, j_init = i * stride, j * stride
                i_stop, j_stop = i_init + kernel_size, j_init + kernel_size
                
                # Unfolding della regione da approssimare
                x_unfold = unfold_X(X = X[:, i_init : i_stop, j_init : j_stop], axis = 0)
                
                # Unfolding dei kernels
                Z = unfold_tensors(self.kernels[idx][0], self.kernels[idx][1])
                
                # Inizializzazione x_hat
                x_hat = self.prng.uniform(x_unfold.min(), x_unfold.max(), (n_samples, rank))
                
                # Calcolo di x_hat 
                X_hat[:, :, i, j] = SteepestGradientDescent(V = x_hat, Z = Z, X_unfold = x_unfold, 
                                                           gd_iter_max = n_iter_max, gd_tol = tol)
                
                # Aggiornament indice
                idx += 1
        
        return X_hat
    
    
    
    # Metodo transform
    def flat_transform(self, X, n_iter_max = 50, tol = 1e-8):
        
        # Estrazione numero di samples
        n_samples = X.shape
        
        # Calcolo del numero di kernels
        n_samples, n, m = X.shape
        kernel_size, stride = self.convolution_params
        rank = self.parafac_params['rank']
        n_kernels = round((n - kernel_size) / stride) + 1
        m_kernels = round((m - kernel_size) / stride) + 1
        
        # Inizializzazione X_hat
        X_hat = np.empty((n_samples, rank * n_kernels * m_kernels))
        
        # Inizializzazione indicizzazione
        idx = 0
        
        # Iterazioni sui kernels
        for i in range(n_kernels):
            for j in range(m_kernels):
                
                # Calcolo degli indici
                i_init, j_init = i * stride, j * stride
                i_stop, j_stop = i_init + kernel_size, j_init + kernel_size
                
                # Unfolding della regione da approssimare
                x_unfold = unfold_X(X = X[:, i_init : i_stop, j_init : j_stop], axis = 0)
                
                # Unfolding dei kernels
                Z = unfold_tensors(self.kernels[idx][0], self.kernels[idx][1])
                
                # Inizializzazione x_hat
                x_hat = self.prng.uniform(x_unfold.min(), x_unfold.max(), (n_samples, rank))
                
                # Calcolo di x_hat 
                X_hat[:, idx * rank : (idx + 1) * rank] = SteepestGradientDescent(V = x_hat, Z = Z, X_unfold = x_unfold, 
                                                                                  gd_iter_max = n_iter_max, 
                                                                                  gd_tol = tol)
                
                # Aggiornament indice
                idx += 1
        
        return X_hat