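'''
Breathing-frequency estimation from 8x8 ranging (ToF) frames, replayed from
a recorded log file or read live from a serial port.

04 11 2025 Yang Haifeng
'''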
import numpy as np
import serial
import time
from collections import deque
from scipy.ndimage import median_filter
import traceback
import threading
from datetime import datetime
from scipy.linalg import svd
from scipy.linalg import eig
from numpy.linalg import pinv
from scipy.linalg import hankel
from scipy import signal
import re
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import toeplitz
from scipy.fft import fft, ifft

import numpy as np
import matplotlib.pyplot as plt

import pandas as pd

import numpy as np

from sklearn.linear_model import Ridge
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import pandas as pd
from sklearn.linear_model import ElasticNet
from sklearn.linear_model import RidgeCV, LassoCV, ElasticNetCV
from sklearn.model_selection import cross_val_score


SERIAL_PORT = 'COM3'
BAUDRATE = 921600
# Sensor grid size; index i of a "RangingData<i>=" field maps to
# row i // NUM_COLS, column i % NUM_COLS.
NUM_ROWS, NUM_COLS = 8, 8
FPS = 60  # used to size the buffers below (values also tried: 30, 15)

# Rolling histories of FPS * 20 frames (20 s if data arrives at FPS),
# filled by the reader thread and consumed by analysis().
spid = deque(maxlen=FPS * 20)         # signal_per_spad grids
data_buffer = deque(maxlen=FPS * 20)  # RangingData grids
running = True  # loop flag polled by metronome(), file_thread() and analysis()

# Per-iteration results appended by analysis() and periodically saved to text files.
log2 = []
log3 = []
log4 = []
log5 = []
log6 = []
timestp = []

paralist = deque(maxlen=50)  # declared global in burg_spectrum(); not otherwise used here
start_time = None  # wall-clock time of the first analysis iteration
background = []    # replaced by analysis() with boolean pixel masks
foreground = []
# Last output of corr_back(), reused when too few pixels are usable. It must
# exist before the first call: corr_back() reads it via `global pre_signal`.
pre_signal = []

def ridge_regre(tof_signal, window_size=55, lambda_val=200):
    """Fit a ridge model to sliding-window amplitude features of a 1-D signal.

    Every window of ``window_size`` consecutive samples (step 1, including the
    final window that ends on the last sample) yields the features
    [peak-to-peak, std, max, min]; the regression target is the window's
    peak-to-peak value. A Ridge model with ``alpha=lambda_val`` is fitted on
    all windows and evaluated on those same windows.

    Returns:
        (predictions, y): in-sample predictions and targets, each of length
        ``len(tof_signal) - window_size + 1``.

    Raises:
        ValueError: if ``tof_signal`` is shorter than ``window_size``.
    """
    tof_signal = np.asarray(tof_signal, dtype=float)
    n_windows = len(tof_signal) - window_size + 1
    if n_windows < 1:
        raise ValueError(
            f"ridge_regre needs at least window_size={window_size} samples, "
            f"got {len(tof_signal)}")

    windows = [tof_signal[i:i + window_size] for i in range(n_windows)]
    # NOTE(review): the target (ptp) is also the first feature, so the model
    # can largely reproduce it from that feature alone -- confirm intended.
    X = np.array([[np.ptp(w), np.std(w), np.max(w), np.min(w)] for w in windows])
    y = X[:, 0].copy()

    ridge = Ridge(alpha=lambda_val)
    ridge.fit(X, y)
    return ridge.predict(X), y
def serial_port(ser):
    """Read newline-delimited text from ``ser`` and buffer the parsed grids.

    Lines starting with 'RangingData' are parsed with read() and appended to
    data_buffer; lines starting with 'signal_per_spad' are parsed with
    read_spid() and appended to spid. A trailing partial line is kept until
    the rest of it arrives. Runs until the module-level ``running`` flag is
    cleared.
    """
    buffer = ""
    while running:
        waiting = ser.in_waiting
        if not waiting:
            # Back off briefly instead of busy-spinning: a tight poll loop
            # burns a core and starves the analysis thread of the GIL.
            time.sleep(0.001)
            continue
        buffer += ser.read(waiting).decode('utf-8', errors='replace')
        # Everything before the last '\n' is complete; keep the remainder.
        *lines, buffer = buffer.split('\n')
        for line in lines:
            line = line.strip()
            if line.startswith('RangingData'):
                data_buffer.append(read(line))
            elif line.startswith('signal_per_spad'):
                spid.append(read_spid(line))
def _parse_grid(line, key):
    """Parse '<key><index>=<value>' fields of ``line`` into a NUM_ROWS x NUM_COLS float grid.

    Index i is placed at row i // NUM_COLS, column i % NUM_COLS. Cells not
    mentioned in ``line`` stay 0. Indices beyond the grid are skipped so a
    corrupted line cannot crash the reader thread with an IndexError.
    """
    grid = np.zeros((NUM_ROWS, NUM_COLS))
    size = NUM_ROWS * NUM_COLS
    for idx, val in re.findall(rf'{re.escape(key)}(\d+)\s*=\s*(\d+)', line):
        index = int(idx)
        if index >= size:
            continue
        grid[index // NUM_COLS, index % NUM_COLS] = float(val)
    return grid


def read(line):
    """Parse a line of 'RangingData<i>=<value>' fields into a distance grid."""
    return _parse_grid(line, 'RangingData')


def read_spid(line):
    """Parse a line of 'signal_per_spad<i>=<value>' fields into a signal grid."""
    return _parse_grid(line, 'signal_per_spad')
    
def corr_back(in_matrix, N=300):
    """Combine interior pixel series into one signal weighted by neighbour correlation.

    ``in_matrix`` is indexed as (frame, row, col). Only the last ``N`` frames
    are used. For every interior pixel (border pixels are skipped because
    they lack a full 8-neighbourhood) whose last-N series contains no NaN,
    the absolute Pearson correlation with each NaN-free neighbour is
    computed. A pixel with at least one such neighbour contributes its series
    scaled by the mean |r|. The result is the mean of those scaled series.

    When 3 or fewer pixels contribute and an earlier result exists (module
    global ``pre_signal``), that earlier result is returned unchanged.
    Otherwise the new mean is returned and stored in ``pre_signal``. With no
    contributing pixels and no earlier result, ``np.mean`` of an empty array
    yields NaN.
    """
    global pre_signal
    in_matrix = np.asarray(in_matrix)
    _, rows, cols = in_matrix.shape
    tail = in_matrix[-N:]
    # valid[i, j]: pixel (i, j) has no NaN in the analysed window.
    valid = ~np.isnan(tail).any(axis=0)
    offsets = [(-1, -1), (-1, 0), (-1, 1),
               (0, -1),           (0, 1),
               (1, -1),  (1, 0),  (1, 1)]

    values = []
    for i in range(1, rows - 1):
        for j in range(1, cols - 1):
            if not valid[i, j]:
                continue
            center = tail[:, i, j]
            corr_n = [abs(pearson_correlation(center, tail[:, i + di, j + dj]))
                      for di, dj in offsets if valid[i + di, j + dj]]
            if corr_n:
                # Scale rather than threshold: weak pixels still contribute.
                values.append(center * np.mean(corr_n))
    values = np.array(values)

    # globals().get avoids a NameError if pre_signal was never initialised.
    previous = globals().get('pre_signal', [])
    if len(values) <= 3 and np.size(previous) > 0:
        return previous
    s_out = np.mean(values, axis=0)
    pre_signal = s_out
    return s_out
def pearson_correlation(x, y):
    """Return the Pearson correlation coefficient of ``x`` and ``y``.

    Both inputs are converted to arrays and flattened to 1-D first. If either
    input has zero variance the coefficient is undefined, and the integer 0
    is returned instead of dividing by zero.
    """
    a = np.array(x).flatten()
    b = np.array(y).flatten()
    dev_a = a - np.mean(a)
    dev_b = b - np.mean(b)
    covariance_sum = np.sum(dev_a * dev_b)
    scale = np.sqrt(np.sum(dev_a ** 2) * np.sum(dev_b ** 2))
    if scale == 0:
        return 0
    return covariance_sum / scale

def burg_method(x, order):
    """Estimate autoregressive coefficients of ``x`` with Burg's method.

    Returns:
        (AK, error): ``AK`` has length ``order + 1`` with ``AK[0] == 1`` and
        holds the prediction-error filter
        A(z) = 1 + AK[1] z^-1 + ... + AK[order] z^-order.
        ``error`` starts at mean(x**2) and is multiplied by (1 - gamma**2)
        for each stage's reflection coefficient gamma.
    """
    x = np.asarray(x, dtype=np.float64)
    N = len(x)
    forward = x.copy()   # forward prediction errors (stage 0: x itself)
    backward = x.copy()  # backward prediction errors
    AK = np.zeros(order + 1)
    AK[0] = 1.0
    error = np.dot(x, x) / N
    for k in range(1, order + 1):
        f_k = forward[k:]
        b_k = backward[k - 1:-1]
        num = -2.0 * np.dot(f_k, b_k)
        den = np.dot(f_k, f_k) + np.dot(b_k, b_k)
        gamma = 0 if abs(den) < 1e-12 else num / den

        # Levinson step: a_k[i] = a_{k-1}[i] + gamma * a_{k-1}[k-i].
        # Copy: AK[:k+1] is a view, and updating AK in place would otherwise
        # read coefficients that were already overwritten in this stage.
        a_prev = AK[:k + 1].copy()
        AK[1:k] = a_prev[1:k] + gamma * a_prev[k - 1:0:-1]
        AK[k] = gamma

        # Stage-k errors, both computed from stage k-1 values before writing.
        new_forward = f_k + gamma * b_k
        new_backward = b_k + gamma * f_k
        forward[k:] = new_forward
        backward[k:] = new_backward

        error *= 1 - gamma ** 2
    return AK, error

def burg_spectrum(x, order=8, res=512, fs=5.0):
    """Burg autoregressive power spectral density of ``x``.

    Fits an AR model of ``order`` with burg_method() and evaluates
    PSD(w) = E / |A(e^{-jw})|^2 at ``res`` evenly spaced w in [0, pi].

    Returns:
        (freqs, psd): ``freqs = w * fs / (2 * pi)``, i.e. 0 .. fs/2 in the
        units of ``fs``, and the corresponding PSD values.
    """
    x = np.asarray(x, dtype=np.float64)
    coeffs, noise_power = burg_method(x, order)

    w = np.linspace(0, np.pi, res)
    # A(z) = sum_k coeffs[k] * z**k at z = e^{-jw}. polyval expects the
    # highest power first, hence the reversal; it evaluates all w at once.
    a_of_w = np.polyval(coeffs[::-1], np.exp(-1j * w))
    psd = noise_power / np.abs(a_of_w) ** 2
    freqs = w * fs / (2 * np.pi)
    return freqs, psd
def esprit(s_in, fs, n_s):
    """Estimate ``n_s`` frequencies in ``s_in`` (sampled at ``fs``) with ESPRIT.

    The Hankel matrix of the signal (h_m) is decomposed with an SVD. Its first
    ``n_s`` left singular vectors span the signal subspace. The rotation
    between that subspace with its last row dropped and with its first row
    dropped is solved by pseudo-inverse. The angles of the rotation's
    eigenvalues give the frequencies.

    Returns:
        Absolute frequencies, one per eigenvalue, in the order
        scipy.linalg.eig returns them (not sorted).
    """
    left_vectors = svd(h_m(s_in))[0]
    subspace = left_vectors[:, :n_s]
    rotation = pinv(subspace[:-1]) @ subspace[1:]
    eigenvalues = eig(rotation)[0]
    return np.abs(np.angle(eigenvalues) * fs / (2 * np.pi))
def h_m(s_in, L=None):
    """Return the L x (N - L + 1) Hankel matrix of ``s_in`` (N = len(s_in)).

    ``L`` defaults to N // 2. Entry (r, c) equals s_in[r + c].
    """
    if L is None:
        L = len(s_in) // 2
    n_cols = len(s_in) - L + 1
    first_column = s_in[:L]
    last_row = s_in[L - 1:L - 1 + n_cols]
    return hankel(first_column, last_row)

def _init_figure():
    """Create the 2x2 live dashboard and return the artists analysis() updates.

    Top-left: heat map with one text label per grid cell. Top-right: spectrum
    axes with two lines. Bottom-left: time-domain axes with two lines.
    Bottom-right: time-domain axes with one line.
    """
    plt.ion()
    fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(7, 5))
    ax1, ax2, ax3, ax4 = axs.flatten()[:4]

    im1 = ax1.imshow(np.zeros((NUM_ROWS, NUM_COLS)), cmap='viridis')
    cbar1 = fig.colorbar(im1, ax=ax1)
    cbar1.set_label('Distance')
    ax1.set_title('8x8 Spatial Distance')
    ax1.set_xticks(range(NUM_COLS))
    ax1.set_yticks(range(NUM_ROWS))
    labels = [
        [ax1.text(j, i, '', ha='center', va='center', color='white',
                  fontsize=8, fontweight='bold')
         for j in range(NUM_COLS)]
        for i in range(NUM_ROWS)
    ]

    line2, = ax2.plot([None], [None], 'b-', linewidth=2)
    line2_2, = ax2.plot([], [], 'r-', linewidth=2, label='Data 2')  # second curve
    ax2.set_xlabel('Frequency (Hz)')
    ax2.set_ylabel('Power Spectral Density')
    ax2.set_title('Frequency Spectrum')
    ax2.grid(True, alpha=0.3)
    ax2.set_xlim(0, 0.8)

    line3, = ax3.plot([], [], 'b-', linewidth=2)
    line3_2, = ax3.plot([], [], 'r-', linewidth=2, label='Data 2')  # second curve
    ax3.set_xlabel('Time')
    ax3.set_ylabel('Distance')
    ax3.set_title('Time domain')
    ax3.grid(True, alpha=0.3)

    line4, = ax4.plot([], [], 'b-', linewidth=2)
    ax4.set_xlabel('Time')
    ax4.set_ylabel('Distance')
    ax4.set_title('Time domain')
    ax4.grid(True, alpha=0.3)

    return dict(ax2=ax2, ax3=ax3, ax4=ax4, im1=im1, labels=labels,
                line2=line2, line2_2=line2_2, line3=line3, line3_2=line3_2,
                line4=line4)


def analysis():
    """Main loop: estimate breathing frequency from the shared buffers and plot it.

    Once data_buffer and spid each hold at least FPS * 2 frames, every pass:
      1. segments the latest raw frame: 400 < d < 650 -> foreground,
         d > 650 -> background (module globals ``foreground``/``background``);
      2. combines the foreground pixel series with corr_back(), then applies
         a size-5 median filter and detrends;
      3. smooths that with ridge_regre(), edge-pads back to full length and
         detrends again;
      4. computes Burg spectra of order 20 and 32 (fs = 1.6) and an ESPRIT
         estimate;
      5. redraws the dashboard and appends results to the module-level logs,
         writing them to text files every 10 passes.
    Runs until the module-level ``running`` flag is cleared.
    """
    global background, foreground, start_time
    ui = _init_figure()
    ax2, ax3, ax4 = ui['ax2'], ui['ax3'], ui['ax4']
    time_max = 20  # x-axis span of the time-domain plots
    min_frames = FPS * 2

    while running:
        frames = list(data_buffer)
        sigma = list(spid)
        if len(frames) < min_frames or len(sigma) < min_frames:
            # Wait without busy-spinning: a bare `continue` pegs a core,
            # competes with the reader thread for the GIL and freezes the GUI.
            plt.pause(0.05)
            continue

        ini_frames = np.array(frames)  # (n_frames, rows, cols) raw readings
        last = ini_frames[-1]

        # 1. Segment on the latest frame.
        foreground = (last > 400) & (last < 650)
        background = last > 650
        # Foreground series with every other pixel set to NaN. np.where
        # broadcasts the 2-D mask over all frames, so shapes always agree even
        # when data_buffer and spid were snapshotted with different lengths.
        fg_pixels = np.where(foreground, ini_frames, np.nan)

        # 2. Correlation-weighted foreground signal, despiked and detrended.
        extracted = signal.detrend(median_filter(corr_back(fg_pixels), size=5))

        # 3. Ridge-smoothed signal, edge-padded back to the input length.
        ridge_fit, _ = ridge_regre(extracted)
        pad_length = len(extracted) - len(ridge_fit)
        smooth = signal.detrend(np.pad(ridge_fit, (0, pad_length), mode='edge'))

        # 4. Spectral estimates; AR order is capped by the series length.
        order20 = min(20, len(smooth) - 1)
        order32 = min(32, len(smooth) - 1)
        freqs20, psd20 = burg_spectrum(smooth, order20, 256, 1.6)
        freqs32, psd32 = burg_spectrum(smooth, order32, 256, 1.6)

        # 5a. Heat map: hide (NaN) pixels whose latest reading exceeds the
        # mean of the whole buffer, then show the latest frame.
        far = last > ini_frames.mean()
        ini_frames[:, far] = np.nan
        latest_frame = ini_frames[-1]
        ui['im1'].set_data(latest_frame)
        ui['im1'].set_clim(vmin=0, vmax=700)  # fixed colour range 0-700
        for i in range(NUM_ROWS):
            for j in range(NUM_COLS):
                ui['labels'][i][j].set_text(f'{latest_frame[i, j]:.1f}')
                ui['labels'][i][j].set_color('white')

        # 5b. Spectrum; the red curve currently shows the same order-32 PSD.
        ui['line2'].set_data(freqs32, psd32)
        ui['line2_2'].set_data(freqs32, psd32)
        ax2.set_xlim(0, 0.8)  # 0 - 0.8 Hz
        ax2.set_ylim(psd32.min(), psd32.max())

        # 5c. Time domain: smoothed signal (blue) and pre-ridge signal (red).
        t = np.linspace(0, time_max, len(smooth))
        ui['line3'].set_data(t, smooth)
        if np.any(np.isnan(extracted)):
            lo, hi = smooth.min(), smooth.max()
        else:
            ui['line3_2'].set_data(t, extracted)
            both = np.concatenate([smooth, extracted])
            lo, hi = both.min(), both.max()
        ax3.set_xlim(0, time_max)
        ax3.set_ylim(lo, hi)

        ui['line4'].set_data(t, smooth)
        ax4.set_xlim(0, time_max)
        ax4.set_ylim(smooth.min(), smooth.max())

        plt.tight_layout()
        plt.draw()
        plt.pause(0.01)

        # 6. Report and log.
        idx32 = np.argmax(psd32)
        idx20 = np.argmax(psd20)
        print(f"                                   breath frequency:{freqs32[idx32]:.2f} hz")
        freqs_e = esprit(smooth, 1.6, 2)  # ESPRIT frequency estimate

        current_time = time.time()
        if start_time is None:
            start_time = current_time
        interval = current_time - start_time
        mins = int(interval // 60)
        secs = int(interval % 60)

        # Each model's peak index is read from its own frequency axis:
        # order 20 -> auto.txt, order 32 -> auto32order.txt.
        log2.append(freqs20[idx20])
        log4.append(freqs32[idx32])
        log3.append(freqs_e[-1])
        # clean.txt keeps its existing layout: the smoothed series twice per pass.
        log5.extend(smooth)
        log5.extend(smooth)
        log6.extend(extracted)
        timestp.append(f"{mins}:{secs:02d}\n")
        # NOTE(review): log5/log6 grow without bound and every file is
        # rewritten in full on each save.
        if len(log2) % 10 == 0:
            np.savetxt("timestamp.txt", [timestp], fmt='%s')
            np.savetxt("auto.txt", log2, fmt="%.4f")
            np.savetxt("auto32order.txt", log4, fmt="%.4f")
            np.savetxt("esprit.txt", log3, fmt="%.4f")
            np.savetxt("clean.txt", log5, fmt="%.4f")
            np.savetxt("orgtime.txt", log6, fmt="%.4f")
            print(f"                       Saved at iteration {len(log2)}")
    
def snr_analysis(f, psd):
    """Print and return the SNR (dB) of the dominant peak of a spectrum.

    Signal power is the trapezoidal integral of ``psd`` over frequencies
    strictly inside a 0.3 Hz band centred on the peak frequency. Noise power
    is the integral over all of ``f`` minus the signal power. Returns ``inf``
    when the noise power is not positive (all power lies in the band).

    Args:
        f: 1-D array of frequencies in Hz.
        psd: 1-D array of spectral density values, same length as ``f``.
    """
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    peak = f[np.argmax(psd)]
    band = 0.3  # total band width in Hz, centred on the peak
    in_band = (f > peak - band / 2) & (f < peak + band / 2)

    signal_power = integrate(psd[in_band], f[in_band])
    noise = integrate(psd, f) - signal_power
    if noise > 0:
        snr_db = 10 * np.log10(signal_power / noise)
    else:
        snr_db = np.inf
    print(f"                        SNR: {snr_db:.2f} dB")
    return snr_db

def metronome():
    """Print paced-breathing cues to stdout.

    Starts at 15 breaths per minute. Each breath prints an inhale cue, sleeps
    40 % of the 60/bpm-second period, prints an exhale cue and sleeps the
    remaining 60 %. After as many breaths as the current rate, the rate rises
    by 5 BPM, and a phase banner is printed when it changes. On reaching
    30 BPM "Finish..." is printed and the thread idles forever. The loop also
    ends between breaths if the module-level ``running`` flag is cleared.
    """
    rate = 15
    announced = 0
    breaths_done = 0
    while running:
        period = 60.0 / rate
        if rate != announced:
            print(f"phase {(rate-15)/5+1}: {rate} BPM")
            announced = rate
        print("   < Inhale")
        time.sleep(period * 0.4)
        print("   > Exhale")
        time.sleep(period * 0.6)

        breaths_done += 1
        if breaths_done == rate:
            rate, breaths_done = rate + 5, 0
        if rate == 30:
            print("Finish...")
            while True:
                time.sleep(1)
def file_thread(path='64Res1Meter2025-11-07 084454.txt'):
    """Replay a recorded sensor log into data_buffer and spid.

    Reads ``path`` line by line, sleeping 0.01 s after each line. Only lines
    whose running count satisfies ``line_count % 5 in (0, 1)`` are parsed.
    'RangingData' lines go through read() into data_buffer (and the module
    global ``latest_frame``). 'signal_per_spad' lines go through read_spid()
    into spid.

    At end of file the elapsed replay time and line count are printed once a
    second, indefinitely: the thread does not return after EOF. A missing
    file or any other error is reported on stdout instead of being raised.
    """
    global latest_frame
    line_count = 0
    try:
        with open(path, 'r', encoding='utf-8') as file:
            st = time.time()
            while running:
                line = file.readline()
                if not line:
                    elapsed = time.time() - st
                    # Hold at EOF and keep reporting how long the replay took.
                    while True:
                        print("elapsed time:", elapsed, "s")
                        print(line_count)
                        time.sleep(1)

                line = line.strip()
                line_count += 1
                # Keep 2 of every 5 lines (count % 5 == 0 or 1).
                if line_count % 5 in (0, 1):
                    if line.startswith('RangingData'):
                        latest_frame = read(line)
                        data_buffer.append(latest_frame)
                    elif line.startswith('signal_per_spad'):
                        spid.append(read_spid(line))
                time.sleep(0.01)
    except FileNotFoundError:
        print(f"cannot find '{path}'")
    except Exception as e:
        # Thread entry point: report with traceback rather than die silently.
        print(f"error {e}")
        traceback.print_exc()
def main(use_serial=False):
    """Start the data-source and metronome threads, then run analysis().

    Args:
        use_serial: if True, read live data from SERIAL_PORT at BAUDRATE
            with serial_port(). Otherwise (the default) replay the recorded
            log file with file_thread().
    """
    if use_serial:
        ser = serial.Serial(SERIAL_PORT, BAUDRATE, timeout=0.1)
        source = threading.Thread(target=serial_port, args=(ser,), daemon=True)
    else:
        source = threading.Thread(target=file_thread, daemon=True)
    source.start()

    threading.Thread(target=metronome, daemon=True).start()
    # analysis() loops on this (main) thread while `running` is set.
    analysis()

# Start the pipeline only when run as a script; importing this module
# defines the functions and globals without starting any thread.
if __name__ == "__main__":
    main()