import numpy as np
import math

import scipy.optimize
from scipy.signal import argrelextrema
import matplotlib.pyplot as plt

import funk
import time


class Contact:
    """Description of a frictional contact interface.

    Holds the contact stiffnesses, the friction coefficient, the normal
    preload, the number of harmonics and the contact dofs, together with the
    settings consumed by ``scipy.optimize.fsolve`` in the non-linear solution
    (``final_tolerance`` is used as ``xtol``, ``max_iterations`` as
    ``maxfev``).
    """

    def __init__(self, Contact_dof_r, Contact_dof_l, Kt = 0, Kn = 0, mu = 0,
                 nH = 1, N0=0):
        """Store the contact parameters.

        Parameters
        ----------
        Contact_dof_r, Contact_dof_l :
            Dof indices of the two contact sides (presumably right/left —
            confirm). They are joined with ``+`` into ``Contact_dof``, so
            plain lists are expected.
        Kt, Kn :
            Tangential and normal contact stiffness. A value equal to 0 is
            stored as 1e-15 instead.
        mu :
            Friction coefficient.
        nH :
            Number of harmonics.
        N0 :
            Contact normal load.
        """
        # NOTE(review): an exact zero stiffness is replaced by a tiny value,
        # presumably to keep downstream terms non-singular — confirm. As a
        # consequence a later test such as ``con.Kt == 0`` never matches.
        tiny_stiffness = 1e-15
        self.Kt = tiny_stiffness if Kt == 0 else Kt
        self.Kn = tiny_stiffness if Kn == 0 else Kn

        # Physical parameters of the contact.
        self.mu = mu                  # friction coefficient
        self.N0 = N0                  # contact normal load
        self.nH = nH                  # number of harmonics
        self.Nstep = 100              # step for hysteresis cycle

        # Contact dofs of both sides, plus their concatenation.
        self.Contact_dof_r = Contact_dof_r
        self.Contact_dof_l = Contact_dof_l
        self.Contact_dof = Contact_dof_r + Contact_dof_l

        # Settings read by the fsolve call of the non-linear solution.
        self.final_tolerance = 1.e-6  # tolerance (xtol)
        self.max_iterations = 100     # max function evaluations (maxfev)
        self.eps = 2.0 ** -52         # double-precision machine epsilon

class Solution:
    """Linear and non-linear (multi-harmonic) frequency-response analysis.

    The constructor computes, through ``funk.modal``, the linear response of
    the structure without contact stiffness ("slip") and with the contact
    stiffnesses of ``con`` added to the contact dofs ("stick").
    :meth:`nonlinear_solution` then solves ``funk.dynamic_balance`` with
    ``scipy.optimize.fsolve`` at every analysis frequency; the remaining
    methods extract resonance peaks and plot/save the results.
    """

    def __init__(self, M, C, K, Fef, freq_range, con, reverse = 0, cs = False):
        """Run the linear slip and stick analyses.

        Parameters
        ----------
        M, C, K :
            Mass, damping and stiffness matrices passed to ``funk.modal``.
            ``K`` must be a square matrix indexable as ``K[i][j]``.
        Fef :
            Excitation vector; ``len(Fef)`` is used as the number of dofs.
        freq_range :
            Frequency range forwarded to ``funk.modal``.
        con : Contact
            Contact stiffnesses (``Kt``, ``Kn``) and dofs (``Contact_dof``).
        reverse :
            Forwarded unchanged to ``funk.modal``.
        cs : bool
            True when the analysis is one step of a cyclic-symmetry sweep.
        """
        txt = """\n MODAL ANALYSIS - LINEAR\n"""
        print(txt)

        self.M = M
        self.C = C
        self.K = K
        self.Fef = Fef

        self.il = len(self.Fef)

        # cyclic symmetry
        self.cs = cs

        start = time.time()

        txt = """Calculating slip linear solution...\t"""
        print(txt, end=' ')
        # Free solution (pure slip contact)
        self.x_slip, self.f_n, self.f = funk.modal(
            self.M, self.C, self.K, self.Fef, freq_range, reverse
        )

        txt = """done!"""
        print(txt)

        # NOTE(review): Contact replaces a zero Kt/Kn by 1e-15, so for
        # Contact instances this condition is always true and the stick
        # solution is always computed (plot_linear relies on x_stick).
        if not (con.Kt == 0 and con.Kn == 0):
            # Adding contact stiffness to the free solution
            K_stick = self._stick_stiffness(con)

            txt = """Calculating stick linear solution...\t"""
            print(txt, end=' ')
            # Stick solution (pure stick contact).
            # NOTE(review): this overwrites self.f; presumably identical to
            # the slip frequencies since freq_range is the same — confirm.
            self.x_stick, self.f_n_stick, self.f = funk.modal(
                self.M, self.C, K_stick, self.Fef, freq_range, reverse
            )
            # Sort like the slip frequencies so both reports list the
            # lowest natural frequencies first.
            self.f_n_stick.sort()

        end = time.time()

        self.f_n.sort()

        txt = """done!\nelapsed time: """
        print(txt, end - start, 'sec\n')
        txt = """\n natural frequencies (slip): """
        print(txt, self.f_n[:4])
        if hasattr(self, 'f_n_stick'):
            txt = """\n natural frequencies (stick): """
            print(txt, self.f_n_stick[:4])

    def _stick_stiffness(self, con):
        """Return a copy of ``self.K`` with the contact stiffnesses added.

        Dofs are grouped three per node (x, y, z): a contact dof with
        ``(cd + 1) % 3 != 0`` is tangential and receives ``con.Kt``, the
        others are normal (z) and receive ``con.Kn``.
        """
        K = np.asarray(self.K)
        # Promote integer matrices to float so that the added stiffness is
        # not truncated when written back into the array.
        K_stick = np.array(K, dtype=np.result_type(K.dtype, np.float64))
        for cd in con.Contact_dof:
            if (cd + 1) % 3:
                # tangential dof (x or y)
                K_stick[cd, cd] += con.Kt
            else:
                # normal dof (z)
                K_stick[cd, cd] += con.Kn
        return K_stick

    def nonlinear_solution(self, con, startstick = False, N = 1, h = 0):
        """Solve the non-linear dynamic balance at every analysis frequency.

        For each frequency of ``self.f`` the real-valued unknowns are found
        with ``scipy.optimize.fsolve`` applied to ``funk.dynamic_balance``.
        There are ``con.nH * Ndof`` complex unknowns. Only the first ``Ndof``
        of them are seeded: at the first frequency from the linear stick
        (``startstick``) or slip response, afterwards from the previous
        frequency's result. The remaining unknowns (presumably the higher
        harmonics) always start at zero.

        Parameters
        ----------
        con : Contact
            Uses ``nH``, ``final_tolerance`` (xtol) and ``max_iterations``
            (maxfev).
        startstick : bool
            Start from the stick instead of the slip linear solution.
        N :
            Forwarded to ``funk.dynamic_balance``; also used for the
            remaining-time estimate of a cyclic-symmetry sweep.
        h :
            Forwarded to ``funk.dynamic_balance``; harmonic index of the
            current cyclic-symmetry step (used for the time estimate).

        Returns
        -------
        tuple
            ``self.X``: one complex array per dof with the response of the
            first ``Ndof`` unknowns over ``self.f``. ``self.converged``
            flags the frequencies where fsolve reported convergence.
        """
        self.N = N
        self.startstick = startstick

        txt = """\n MODAL ANALYSIS - NON-LINEAR\n"""
        print(txt)

        start = time.time()

        omega = self.f * 2. * math.pi
        n_freq = len(omega)
        Ndof = len(self.Fef)
        X = np.zeros((Ndof, n_freq), dtype=complex)
        self.converged = np.zeros(n_freq, dtype=bool)

        # Force matrix (Ndof x nH): the excitation is applied to the first
        # harmonic block only, the other nH - 1 blocks are unloaded.
        F = list(self.Fef) + [0.] * (Ndof * (con.nH - 1))
        F = np.transpose(np.reshape(F, (con.nH, Ndof)))

        # Complex initial guess; allocating it as complex keeps the later
        # assignment of complex responses from failing or dropping the
        # imaginary part when the linear guess happens to be real.
        X0 = np.zeros(con.nH * Ndof, dtype=complex)
        if self.startstick:
            # Linear stick response at the 1st frequency.
            first_guess = [self.x_stick[i][0] for i in range(len(self.x_stick))]
        else:
            # Linear slip response at the 1st frequency.
            first_guess = np.transpose(self.x_slip)[0]
        X0[:Ndof] = np.asarray(first_guess)[:Ndof]

        for nw, om in enumerate(omega):
            if nw > 0:
                # Continuation: previous response for the other frequencies.
                X0[:Ndof] = X[:, nw - 1]

            # From complex to real numbers.
            x0 = tuple(funk.compl2real(X0))
            x1, _info, ier, _msg = scipy.optimize.fsolve(
                funk.dynamic_balance, x0, args=(
                    self, con, F, om, self.cs, self.N, h
                ), fprime=None,
                xtol=con.final_tolerance,
                maxfev=con.max_iterations,
                full_output=True,
            )
            # ier == 1 is fsolve's "solution found" flag.
            self.converged[nw] = ier == 1

            print('\r', (nw + 1) * 100 // n_freq, ' %', end='')

            # From real to complex numbers.
            X_tmp = funk.real2compl(x1)
            X[:, nw] = np.asarray(X_tmp)[:Ndof]

        self.X = tuple(X)

        end = time.time()
        txt = """\n------------\nDone!\nelapsed time: \t\t"""
        print(txt, (end - start)/60, 'min\n')
        n_failed = int(n_freq - np.count_nonzero(self.converged))
        if n_failed:
            print('WARNING: fsolve did not converge at', n_failed, 'of',
                  n_freq, 'frequencies (see Solution.converged)\n')
        if self.cs:
            txt = """Full analysis estimated \ntime remining: \t\t"""
            print(txt, (end - start) * (self.N//2 - h) / 60, 'min\n')

        return self.X

    def get_result(self, which_dof = 9999):
        """Extract the non-linear resonance peaks of one dof.

        The peaks are the local maxima of ``abs(self.X[which_dof])``.
        Frequencies are stored sorted in ``self.f_n_nl`` and the matching
        amplitudes, in the same order, in ``self.A``.

        Returns
        -------
        numpy.ndarray
            ``self.f_n_nl``; empty (as is ``self.A``) when no non-linear
            solution exists or ``which_dof`` is out of range (e.g. the 9999
            default).
        """
        if which_dof != 9999:
            self.out = which_dof

        txt = """-----------------\n ANALISYS RESULT\n\
-----------------\nNatural frequencies [Hz]:\t\t\t"""
        print(txt, self.f_n[:4])
        try:
            response = np.abs(np.asarray(self.X[which_dof]))
        except (AttributeError, IndexError):
            # No non-linear solution yet, or dof out of range: report empty
            # results instead of returning a stale or undefined value.
            print('\nNo non-linear solution available for dof', which_dof)
            self.f_n_nl = np.array([])
            self.A = np.array([])
            return self.f_n_nl

        # Indices of the local maxima of the amplitude curve.
        peaks = argrelextrema(response, np.greater)[0]
        freqs = np.asarray(self.f)[peaks]
        # Sort frequencies and amplitudes together so that A[k] remains the
        # amplitude of the peak at f_n_nl[k] even if self.f is not ascending.
        order = np.argsort(freqs, kind='stable')
        self.f_n_nl = freqs[order]
        self.A = response[peaks][order]

        txt = """\nNon-linear solution natural frequencies [Hz]:\t"""
        print(txt, self.f_n_nl[:6])

        return self.f_n_nl

    def plot_linear(self, which_dof = 9999):
        """Plot the linear slip (blue) and stick (green) amplitudes, figure 1.

        With the default ``which_dof=9999`` every dof is plotted, otherwise
        only ``which_dof``. Stick curves are skipped when no stick solution
        exists. The figure is saved to ``blade_linear.png``.
        """
        if which_dof != 9999:
            self.out = which_dof
        else:
            self.out = 0

        plt.figure(1)
        plt.grid(True)
        # Visibility flag passed positionally: the ``b=`` keyword was
        # deprecated and later removed from matplotlib's grid().
        plt.grid(True, which='minor', linestyle='--')
        plt.suptitle('Frequency vs amplitude')
        plt.xlabel('Frequency [Hz]')
        plt.ylabel('Amplitude')

        x_stick = getattr(self, 'x_stick', None)
        dofs = range(len(self.x_slip)) if which_dof == 9999 else [which_dof]
        for k, dof in enumerate(dofs):
            # Label only the first curve of each kind to avoid duplicated
            # legend entries.
            first = k == 0
            plt.plot(self.f, abs(self.x_slip[dof]), 'b',
                     label='Slip, isolated' if first else '_nolegend_')
            if x_stick is not None:
                plt.plot(self.f, abs(x_stick[dof]), 'g',
                         label='Stick, isolated' if first else '_nolegend_')
        plt.legend()

        plt.savefig('blade_linear.png', dpi = 300)

    def plot_nonlinear(self, N0, which_dof=9999):
        """Overlay the non-linear amplitudes on figure 1, labelled by ``N0``.

        With the default ``which_dof=9999`` every dof is plotted, otherwise
        only ``which_dof``; nothing is drawn when no non-linear solution
        exists. The figure is saved to ``blade_nonlinear.png``.
        """
        plt.figure(1)
        lab = 'N0 = ' + str(N0)
        X = getattr(self, 'X', None)
        if X is not None and len(X):
            dofs = range(len(X)) if which_dof == 9999 else [which_dof]
            for k, dof in enumerate(dofs):
                # One legend entry per call (i.e. per N0), not per dof.
                plt.plot(self.f, abs(X[dof]),
                         label=lab if k == 0 else '_nolegend_', linewidth=0.5)
            plt.grid(True, which='major', linestyle='--')
            plt.legend()
        plt.savefig('blade_nonlinear.png', dpi = 300)

    def plot_optimization(self, N0, f_n, A):
        """Plot natural frequency (figure 2) and amplitude (figure 3) vs N0.

        The figures are saved to ``blade_N0_fn.png`` and ``blade_N0_A.png``.
        """
        plt.figure(2)
        plt.plot(N0, f_n, 'ro-')
        plt.grid(True, which='minor', linestyle='--')
        plt.suptitle('Normal load vs natural frequency')
        plt.ylabel('Frequency [Hz]')
        plt.xlabel('N0')
        plt.savefig('blade_N0_fn.png', dpi = 300)

        plt.figure(3)
        plt.plot(N0, A, 'ro-')
        plt.grid(True, which='minor', linestyle='--')
        plt.suptitle('Normal load vs amplitude')
        plt.ylabel('Amplitude')
        plt.xlabel('N0')
        plt.savefig('blade_N0_A.png', dpi = 300)

    def plot_cyclic(self, h, num_of_freq = 3):
        """Plot the first ``num_of_freq`` peak frequencies at harmonic ``h``.

        Uses ``self.f_n_nl`` from :meth:`get_result`; families with fewer
        peaks than requested are skipped. Legend labels are added on the
        ``h == 0`` call only. The figure is saved to ``blade_cyclic.png``.
        NOTE(review): plot_optimization also draws on figure 3.
        """
        styles = ('ro', 'bo', 'go', 'ko', 'yo')

        plt.figure(3)
        for i in range(num_of_freq):
            try:
                freq = self.f_n_nl[i]
            except (AttributeError, IndexError):
                # get_result not run yet, or fewer peaks than requested.
                continue
            # Beyond the predefined styles fall back to a default-coloured
            # marker: a single point drawn with only a line style is invisible.
            fmt = styles[i] if i < len(styles) else 'o'
            if h == 0:
                lab = 'Famiglia #' + str(i+1)
                plt.plot(h, freq, fmt, label = lab)
            else:
                plt.plot(h, freq, fmt)
        plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15),
                   fancybox=True, shadow=0, ncol=3)
        plt.grid(True, which='minor', linestyle='--')
        plt.xticks(np.arange(0, h + 1, 1))
        plt.xlabel('Indice armonico - h')
        plt.ylabel('f [Hz]')
        plt.savefig('blade_cyclic.png', dpi = 300)


    
