import numpy as np
import math

import scipy.optimize
import functools
import matplotlib.pyplot as plt

from model import Solution
from harwell_boeing_read import hb_read 

def matrix_extraction(M_directory,K_directory, Fn):
    """Read mass/stiffness matrices and build the force vector and damping.

    Parameters
    ----------
    M_directory, K_directory : path-like
        Files holding the mass (M) and stiffness (K) matrices in
        Harwell-Boeing format, read with ``hb_read``.
    Fn : sequence of (dof, value) pairs
        Only the first pair is used: a force ``Fn[0][1]`` applied on dof
        ``Fn[0][0]``.

    Returns
    -------
    M, C, K, Fef
        Mass, damping, stiffness matrices and the external force vector.
        ``C`` is the proportional damping ``alpha*K + beta*M``.
    """
    txt = """\nExtracting mass and stiff matrix...\t"""
    print(txt, end=' ')

    M = np.array(hb_read(M_directory))
    K = np.array(hb_read(K_directory))

    # Floating-point vector: an integer array would silently truncate the
    # applied force (e.g. 2.5 -> 2).
    Fef = np.zeros(len(M))
    Fef[Fn[0][0]] = Fn[0][1]

    # C: proportional damping matrix.
    alpha = beta = 0.00045
    C = alpha * K + beta * M

    txt = """done!"""
    print(txt)

    return M, C, K, Fef

def ordinate_matrix(M, C, K, Fef, contact_properties = 0):
    """Reorder the dofs as [right contact, left contact, inner].

    Parameters
    ----------
    M, C, K : 2-D ndarray
        Mass, damping and stiffness matrices.
    Fef : 1-D ndarray
        External force vector.
    contact_properties : object or 0
        When 0 (default) nothing is reordered.  Otherwise it must provide
        ``Contact_dof`` (all contact dofs), ``Contact_dof_r`` and
        ``Contact_dof_l`` (right / left contact dofs; ``Contact_dof_r[k]``
        and ``Contact_dof_l[k]`` are kept paired by position).

    Returns
    -------
    M, C, K, Fef
        The permuted matrices and force vector.

    Side effects
    ------------
    ``contact_properties`` is updated in place.  Its dof lists are relabelled
    to the new positions: right dofs first, then left dofs.
    """
    if contact_properties == 0:
        # No contact data: keep the original ordering.
        return M, C, K, Fef

    con = contact_properties
    right = list(con.Contact_dof_r)
    left = list(con.Contact_dof_l)
    contact = set(con.Contact_dof)    # set: O(1) membership test below
    inner = [i for i in range(len(M)) if i not in contact]

    # New order: right contact dofs, left contact dofs, inner dofs.
    index = right + left + inner
    grid = np.ix_(index, index)
    M = M[grid]
    K = K[grid]
    C = C[grid]
    Fef = Fef[index]

    # Relabel using the actual sizes of the right and left sets, so the
    # split is correct even when they do not have the same length.
    n_r = len(right)
    n_l = len(left)
    con.Contact_dof = list(range(n_r + n_l))
    con.Contact_dof_r = list(range(n_r))
    con.Contact_dof_l = list(range(n_r, n_r + n_l))

    return M, C, K, Fef

def modal(M, C, K, F, freq_range, reverse=False):
    """Linear forced response by modal superposition.

    Parameters
    ----------
    M, C, K : 2-D ndarray
        Mass, damping and stiffness matrices.
    F : array-like
        Force vector of length n_dof (a column of shape (n_dof, 1) is
        accepted too).
    freq_range : sequence
        ``(fmin, fmax)`` or ``(fmin, fmax, step)`` in Hz.  With two entries
        the step is ``(fmax - fmin) / 100``.  The end point of the sweep is
        excluded (``np.arange``).
    reverse : bool
        Sweep from fmax down to fmin instead of upwards.

    Returns
    -------
    x : complex ndarray, shape (n_dof, n_freq)
        Response amplitudes in physical coordinates.
    f_n : ndarray
        Natural frequencies in Hz, sorted ascending.  A negative eigenvalue
        gives ``nan`` through ``np.sqrt``.
    f : ndarray
        Excitation frequencies in Hz.

    Only the diagonal of the modal damping matrix is used.  This is exact
    for proportional damping such as ``alpha*K + beta*M``.
    """
    # Explicit import: ``import scipy.optimize`` at file level only loads
    # scipy.linalg as a side effect.
    import scipy.linalg

    # Generalized eigenproblem K psi = lambda M psi.
    eigvals, psi = scipy.linalg.eig(K, M)
    omega_n = np.sqrt(np.array(eigvals.real))
    f_n = omega_n / math.pi / 2
    # Sorted copy for reporting; omega_n and psi keep the solver's order.
    f_n.sort()

    # Frequency range
    fmin = freq_range[0]
    fmax = freq_range[1]
    if len(freq_range) == 2:
        step = (fmax - fmin) / 100.
    else:
        step = freq_range[2]

    if reverse:
        f = np.arange(fmax, fmin, -step)
    else:
        f = np.arange(fmin, fmax, step)
    omega = f * 2. * math.pi

    # Mass normalisation: alpha_j = sqrt(psi_j^T M psi_j) for each column j.
    alpha_n = np.sqrt(np.sum(psi * (M @ psi), axis=0))
    phi = psi / alpha_n

    # Modal damping (diagonal part of phi^T C phi).
    c_mod = np.diag(np.transpose(phi) @ C @ phi)

    # Modal force, one entry per mode.
    n = len(M)
    modal_force = np.reshape(np.transpose(phi) @ F, (n,))

    # Modal amplitudes for every mode (rows) and frequency (columns).
    den = (
        omega_n[:, None] ** 2
        - omega[None, :] ** 2
        + 1j * omega[None, :] * c_mod[:, None]
    )
    q = modal_force[:, None] / den

    x = phi @ q      # Amplitude in physical coordinates

    return x, f_n, f

def compl2real(xcompl):
    """Interleave the real and imaginary parts of complex values.

    ``[a+bj, c+dj]`` becomes ``(a, b, c, d)``.  A complex scalar, a list or
    an ndarray are accepted.  Multi-dimensional input is flattened in C
    order.  Empty input gives an empty tuple.

    Returns
    -------
    tuple of float
        Interleaved values, length ``2 * number_of_elements``.
    """
    # np.asarray instead of ``xcompl == []``: comparing an ndarray with []
    # raises ValueError in recent NumPy for length >= 2.
    values = np.ravel(np.asarray(xcompl))
    xreal = np.zeros(2 * values.size)
    xreal[0::2] = values.real
    xreal[1::2] = values.imag

    return tuple(xreal)

def real2compl(xreal):
    """Build complex values from interleaved real/imaginary parts.

    ``(a, b, c, d)`` becomes ``array([a+bj, c+dj])``.  This is the inverse of
    ``compl2real``.  Tuples, lists and ndarrays are accepted.

    Raises
    ------
    ValueError
        If the number of values is odd.
    """
    values = np.ravel(np.asarray(xreal))
    if values.size % 2:
        raise ValueError(
            "real2compl expects an even number of values, got %d" % values.size
        )
    return values[0::2] + 1j * values[1::2]

def _contact_harmonic_force(cd, rel, con, nH):
    """Return the negated contact force on dof ``cd`` for harmonics 1..nH.

    ``rel`` holds the complex (relative) displacement of dof ``cd`` for
    harmonics 1..nH.  Dofs with ``(cd + 1) % 3 != 0`` are treated as
    tangential, the others as normal.  ``con.N0 / con.Kn`` is used as the 0th
    harmonic of the normal displacement.  The result is negated because the
    residual adds it to the excitation.
    """
    preload = con.N0 / con.Kn
    if (cd + 1) % 3:
        # Tangential dof: the displacement drives the friction element while
        # the normal displacement is held at the static preload.
        U = tuple([0.] + list(rel))
        V = tuple(preload if i == 0 else 0. for i in range(len(U)))
        T, _ = contact_model(U, V, con)
        return -np.array(T[1:nH + 1])
    # Normal dof: preload plus harmonic normal displacement, no tangential motion.
    V = tuple([preload] + list(rel))
    U = tuple(0. for _ in range(len(V)))
    _, N = contact_model(U, V, con)
    return -np.array(N[1:nH + 1])


def dynamic_balance(x, slip, con, F, om, cs = False, N_blade = 1, h = 0):
    """Return the harmonic-balance residual of the system with contact.

    For every harmonic k = 1..nH the residual is

        (K - (k*om)^2 M + 1j*k*om C)^-1 (F_k + FNL_k) - X_k

    with M, C, K taken from ``slip``.

    Parameters
    ----------
    x : sequence of float
        Unknown harmonic amplitudes, real/imaginary parts interleaved
        (decoded by ``real2compl``).  After decoding, entries
        ``Ndof*(k-1) : Ndof*k`` are harmonic k of every dof, where
        ``Ndof = len(slip.x_slip)``.
    slip : object
        Provides ``x_slip`` (its length sets Ndof) and matrices ``M``, ``C``, ``K``.
    con : object
        Contact data: ``nH``, ``N0``, ``Kn``.  It also provides
        ``Contact_dof`` (used when cs is false) or ``Contact_dof_r`` /
        ``Contact_dof_l`` (paired by position, used when cs is true), plus
        the attributes read by ``contact_model``.
    F : 2-D array-like
        ``F[i][k]`` is the excitation of dof i for harmonic k + 1.
        Assumes ``len(F) == Ndof``.
    om : float
        Fundamental circular frequency.
    cs : bool
        If true, contact is between the right and left surfaces of a cyclic
        sector.  The left surface is phase-shifted by
        ``(k)*h*2*pi/N_blade`` for harmonic k.
    N_blade, h : int
        Number of sectors and harmonic index (used only when cs is true).

    Returns
    -------
    tuple of float
        Residual with real/imaginary parts interleaved (``compl2real``).
    """
    nH = con.nH
    Ndof = len(slip.x_slip)

    # From real numbers to complex numbers.
    X = real2compl(x)

    # FNL[i, k]: (negated) nonlinear contact force on dof i, harmonic k + 1.
    FNL = np.zeros((len(F), nH), dtype=complex)

    if not cs:
        for cd in con.Contact_dof:
            rel = [X[Ndof * k + cd] for k in range(nH)]
            FNL[cd, :] = _contact_harmonic_force(cd, rel, con, nH)
    else:
        phi = h * 2 * np.pi / N_blade
        for cd in con.Contact_dof_r:
            # Matching dof on the left surface (same position in the lists).
            cl = con.Contact_dof_l[con.Contact_dof_r.index(cd)]
            # Relative displacement: right minus phase-shifted left.
            rel = [
                X[Ndof * k + cd] - X[Ndof * k + cl] * np.exp(-1j * ((k + 1) * phi))
                for k in range(nH)
            ]
            # Both tangential and normal dofs go through the same helper.
            # Before, the normal branch zeroed V and kept a stale U, so the
            # normal contact force was always zero.
            FNL[cd, :] = _contact_harmonic_force(cd, rel, con, nH)
            # Force on the left surface, phase-shifted back.
            FNL[cl, :] = -np.array([
                FNL[cd][k] * np.exp(1j * ((k + 1) * phi))
                for k in range(nH)
            ])

    Mff = np.array(slip.M)
    Cff = np.array(slip.C)
    Kff = np.array(slip.K)
    F_arr = np.asarray(F)

    res = []
    for k in range(nH):
        wk = (k + 1) * om
        A = F_arr[:, k] + FNL[:, k]
        # Current estimate of harmonic k + 1.
        B = X[Ndof * k:Ndof * (k + 1)]
        res.append(
            np.linalg.solve(Kff - wk**2 * Mff + 1j * wk * Cff, A) - B
        )

    # Harmonic-major flattening, then from complex to real residual.
    res = np.ravel(np.array(res))
    return compl2real(res)

def _harmonics_to_time(coeffs, nstep):
    """Sample one period of ``c0 + sum_k Re(ck * exp(1j*k*theta))``.

    ``coeffs[0]`` is the mean value and ``coeffs[k]`` (k >= 1) is the complex
    amplitude of harmonic k.  ``irfft`` assumes a Hermitian spectrum, so the
    result is always real (only the real part of ``coeffs[0]`` is used).
    """
    spectrum = np.array(coeffs, dtype=complex)
    spectrum[1:] /= 2.
    return np.fft.irfft(nstep * spectrum, n=nstep)


def _time_to_harmonics(signal, nh):
    """Return (mean, harmonic 1, ..., harmonic nh) of one sampled period."""
    spectrum = np.fft.fft(signal) / len(signal)
    return tuple(
        np.array([spectrum[0]] + [2 * spectrum[k] for k in range(1, nh + 1)])
    )


@functools.lru_cache(maxsize=64)
def contact_model(U, V, con):
    """Periodic friction contact forces from a time-domain hysteresis loop.

    Reference: C. Siewert, L. Panning, J. Wallaschek, C. Richter,
    Multiharmonic Forced Response Analysis of a Turbine Blading Coupled by
    Nonlinear Contact Forces, J. Eng. Gas Turbines Power 132(8) (2010)
    082501.

    Parameters
    ----------
    U, V : tuple
        Harmonic coefficients of the tangential (U) and normal (V)
        displacement.  Element 0 is the mean value and element k >= 1 is the
        complex amplitude of harmonic k.  They are tuples so the call can be
        memoised with ``lru_cache``.
    con : object
        Reads ``Nstep`` (samples per period), ``Kt`` / ``Kn`` (tangential /
        normal contact stiffness) and ``mu`` (friction coefficient).
        NOTE(review): the cache keys on ``con`` via its own ``__hash__``.  If
        that is identity-based, changing ``con``'s attributes between calls
        returns stale cached results.

    Returns
    -------
    T, N : tuple
        Harmonic coefficients (same layout as ``U``, ``len(U)`` entries) of
        the tangential and normal contact forces.
    """
    nhbm = len(U) - 1
    nstep = con.Nstep
    kt = con.Kt
    mu = con.mu

    # Fourier coefficients -> one period in the time domain (always real).
    u = _harmonics_to_time(U, nstep)
    v = _harmonics_to_time(V, nstep)

    # Normal force in time domain.
    n = con.Kn * v

    # Slip variable w and tangential force f in time domain.
    w = [0.] * nstep
    f = [0.] * nstep

    # Initial condition: no tangential force at t = 0.
    nt = 0
    f[nt] = 0.
    w[nt] = u[nt]

    # Loop over the period repeatedly until the cycle repeats itself.
    wrapped = False
    while True:
        nt += 1
        if nt >= nstep:
            wrapped = True
            nt -= nstep

        # Value of the previous pass at this same instant.
        f_prev = f[nt]

        if n[nt] <= 0:
            # Contact open: no force, slip follows the displacement.
            f[nt] = 0.
            w[nt] = u[nt]
        else:
            # Predictor: stick.  For nt == 0, w[nt - 1] is w[-1], the last
            # instant of the previous pass.
            f[nt] = kt * (u[nt] - w[nt - 1])
            w[nt] = w[nt - 1]
            # Corrector: above the Coulomb limit the contact slips.
            if abs(f[nt]) > mu * n[nt]:
                f[nt] = np.sign(f[nt]) * mu * n[nt]
                w[nt] = u[nt] - f[nt] / kt

        # Converged when the force matches the previous pass at the same
        # instant.  The state (f, w) then repeats for the rest of the period.
        # The tolerance is only needed (and only computed) after the wrap.
        if wrapped:
            f_tol = 1e-5 * np.max(np.abs(f))
            if abs(f[nt] - f_prev) <= f_tol:
                break

    # Time domain -> Fourier coefficients.
    T = _time_to_harmonics(f, nhbm)
    N = _time_to_harmonics(n, nhbm)

    return T, N

def transformation_matrix(N, x_r, x_i, h):
    """Cyclic-symmetry reduction matrix for harmonic index ``h``.

    The returned ``T`` maps the reduced coordinates ``[x_r; x_i]`` (right
    contact dofs, inner dofs) to ``[x_r; x_l; x_i]``.  The left contact dofs
    are ``x_l = exp(-1j*h*2*pi/N) * x_r``.  Only the lengths of ``x_r`` and
    ``x_i`` are used.

    Returns
    -------
    T : complex ndarray, shape (2*len(x_r) + len(x_i), len(x_r) + len(x_i))
    T_transpose : complex ndarray
        Conjugate transpose of ``T``.
    """
    n_r = len(x_r)
    n_i = len(x_i)
    shift = np.exp(1j * (-h * 2 * np.pi / N))

    T = np.zeros((2 * n_r + n_i, n_r + n_i), dtype=complex)
    T[:n_r, :n_r] = np.identity(n_r)                    # right  <- right
    T[n_r:2 * n_r, :n_r] = np.identity(n_r) * shift     # left   <- shifted right
    T[2 * n_r:, n_r:] = np.identity(n_i)                # inner  <- inner

    return T, np.conj(np.transpose(T))

def stick_contact_cs(sol, con, N, num_of_freq):
    """Cyclic-symmetry modal analysis of the sector with stuck contacts.

    For every harmonic index h = 0..N//2:
    - the matrices and force of ``sol`` are reduced with
      ``transformation_matrix``;
    - a ``Solution`` is built from the reduced system;
    - its first ``num_of_freq`` natural frequencies are plotted (figure 4,
      saved to ``blade_cyclic_stick.png``) and written to
      ``Cyclic_sym_modal_freq_STICK.txt``.

    The block layout of ``transformation_matrix`` requires the dofs of ``sol``
    ordered as [right contact, left contact, inner], with as many left as
    right contact dofs (the order produced by ``ordinate_matrix``).

    Parameters
    ----------
    sol : Solution
        Full-sector solution providing ``x_slip``, ``M``, ``C``, ``K``, ``Fef``.
    con : object
        Provides ``Contact_dof`` and ``Contact_dof_r``.
    N : int
        Number of sectors.
    num_of_freq : int
        Number of natural frequencies (families) reported per h.

    Returns
    -------
    list of Solution
        One reduced solution per harmonic index h.
    """
    # Marker styles of the first five families; further families use the
    # default colour cycle with a circle marker.  A bare plot of a single
    # point with no marker draws nothing.
    family_styles = ['ro', 'bo', 'go', 'ko', 'yo']

    contact = set(con.Contact_dof)
    # right set: contact right surface
    x_r = [sol.x_slip[i][0] for i in con.Contact_dof_r]
    # inner set: everything except the contact dofs
    x_i = [sol.x_slip[i][0] for i in range(len(sol.x_slip)) if i not in contact]

    n_harm = N // 2 + 1
    sol_stick_cs = []
    for h in range(n_harm):
        T, T_transpose = transformation_matrix(N, x_r, x_i, h)

        M_cs = T_transpose @ sol.M @ T
        C_cs = T_transpose @ sol.C @ T
        K_cs = T_transpose @ sol.K @ T
        Fef_cs = T_transpose @ sol.Fef

        sol_h = Solution(M_cs, C_cs, K_cs, Fef_cs, con.Contact_dof)
        sol_stick_cs.append(sol_h)

        plt.figure(4)
        for i in range(num_of_freq):
            # Label each family only once (on its h = 0 point).
            kwargs = {}
            if h == 0:
                kwargs['label'] = 'Famiglia #' + str(i + 1)
            if i < len(family_styles):
                plt.plot(h, sol_h.f_n[i], family_styles[i], **kwargs)
            else:
                plt.plot(h, sol_h.f_n[i], 'o', color='C{}'.format(i % 10), **kwargs)

    # Figure decoration and saving, once all points are drawn.
    plt.figure(4)
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15),
               fancybox=True, shadow=0, ncol=3)
    # Positional flag: the ``b=`` keyword was renamed to ``visible`` in
    # Matplotlib 3.5 and later removed.
    plt.grid(True, which='major', linestyle='--')
    plt.xticks(np.arange(0, N//2 + 1, 1))
    plt.xlabel('Indice armonico - h')
    plt.ylabel('f [Hz]')
    plt.savefig('blade_cyclic_stick.png', dpi = 300)

    def _format_freq(value):
        # Short values get two decimals plus an extra tab to keep columns aligned.
        if len("{0:,.3f}".format(value)) < 8:
            return "{0:,.2f}".format(value) + "\t"
        return "{0:,.3f}".format(value)

    data = [
        [_format_freq(sol_stick_cs[h].f_n[i]) for i in range(num_of_freq)]
        for h in range(n_harm)
    ]

    with open("Cyclic_sym_modal_freq_STICK.txt", "w") as txt_file:
        txt_file.write(
            "h\t" + "\t\t".join(
                ["f" + str(i + 1) for i in range(len(data[0]))]
            ) + "\n"
        )
        for h, line in enumerate(data):
            txt_file.write(str(h) + "\t" + "\t".join(line) + "\n")

    return sol_stick_cs
    
