#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Analisi HRV su ECG (riposo e cyclette) con salvataggi automatici
Autore: Annarosa Scalcione
Data: 2025-09-15

Funzionalità:
1) Parsing dei log .txt ECG (3 derivazioni, blocchi da 53 campioni) e salvataggio in .npz
2) Ricostruzione dell'asse temporale campione-per-campione (interpolazione tra timestamp di blocco)
3) Correzione automatica della polarità delle derivazioni, se necessario
4) Calcolo di metriche HRV (time, frequency + alcune non lineari) su:
   - intero segnale di riposo
   - finestre definite del protocollo cyclette (senza/ con resistenza)
5) Valutazione qualità segnale per derivazione
6) Salvataggio dei risultati in un file Excel e dei grafici in PNG
"""

# ======================================================
# LIBRERIE
# ======================================================
import numpy as np
import re
import matplotlib.pyplot as plt
import os
import neurokit2 as nk
import pandas as pd
from datetime import datetime

# Opzioni display pandas (solo per stampa a schermo)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", None)


# ======================================================
# FUNZIONI DI SUPPORTO
# ======================================================

def parse_txt_to_npz(txt_path, npz_path):
    """
    Parse an ECG text log and save its content to a ``.npz`` archive.

    Expected record format (a record may wrap over several lines)::

        HH:MM:SS.xxxxxx [[<ch1 samples>] [<ch2 samples>] [<ch3 samples>]]

    A record is kept only if all three channels parse as numbers, are
    non-empty and have the same length. Records whose length differs from
    the most common block length found in the file are discarded too, so the
    saved channels are always regular 2-D arrays (blocks x samples).

    Parameters
    ----------
    txt_path : str
        Path of the text log to read.
    npz_path : str
        Path of the archive to write. It stores ``timestamps`` (strings) and
        the three channels ``ecg1``, ``ecg2``, ``ecg3``.
    """
    pattern = r"(\d{2}:\d{2}:\d{2}\.\d+)\s+\[\[(.*?)\]\s+\[(.*?)\]\s+\[(.*?)\]\]"
    with open(txt_path, "r") as file:
        text = file.read()
    # DOTALL lets a record span several lines (long sample lists get wrapped).
    matches = re.findall(pattern, text, re.DOTALL)

    records = []
    for ts, *raw_channels in matches:
        try:
            # Strict token-by-token conversion: a malformed token raises
            # ValueError instead of silently yielding a truncated array.
            channels = [np.array(raw.split(), dtype=float) for raw in raw_channels]
        except ValueError:
            print(f"Errore parsing blocco dopo {ts}")
            continue
        block_len = len(channels[0])
        if block_len > 0 and all(len(ch) == block_len for ch in channels):
            records.append((ts, channels))

    if records:
        # Keep only blocks of the most common length: a single truncated
        # block would otherwise make the stacked arrays ragged.
        lengths = np.array([len(channels[0]) for _, channels in records])
        values, counts = np.unique(lengths, return_counts=True)
        expected_len = int(values[np.argmax(counts)])
        kept = [rec for rec, n in zip(records, lengths) if n == expected_len]
        dropped = len(records) - len(kept)
        if dropped:
            print(f"Scartati {dropped} blocchi con lunghezza diversa da {expected_len} campioni")
        records = kept

    timestamps = [ts for ts, _ in records]
    ecg1_list = [channels[0] for _, channels in records]
    ecg2_list = [channels[1] for _, channels in records]
    ecg3_list = [channels[2] for _, channels in records]

    np.savez(npz_path,
             timestamps=np.array(timestamps),
             ecg1=np.array(ecg1_list),
             ecg2=np.array(ecg2_list),
             ecg3=np.array(ecg3_list))
    print(f"File salvato: {npz_path}")


def load_ecg_data(filename, fs=1000):
    """
    Load the ECG channels and block timestamps stored in a ``.npz`` archive.

    Parameters
    ----------
    filename : str
        Path of the archive; it must contain the keys ``ecg1``, ``ecg2``,
        ``ecg3`` and ``timestamps``.
    fs : int, optional
        Sampling frequency in Hz, returned unchanged as the last element of
        the result (default 1000).

    Returns
    -------
    ecg1, ecg2, ecg3 : np.ndarray
        Channel arrays as stored in the archive.
    timestamps : np.ndarray
        Block timestamps as stored in the archive.
    fs : int
        The sampling frequency passed in.
    """
    # NOTE(review): allow_pickle=True can execute arbitrary code when the
    # archive comes from an untrusted source; it is kept so that archives
    # containing object arrays still load. Only open files you produced.
    # The context manager closes the underlying zip file handle; the arrays
    # are read into memory when indexed, before the file is closed.
    with np.load(filename, allow_pickle=True) as data:
        return data["ecg1"], data["ecg2"], data["ecg3"], data["timestamps"], fs


def reconstruct_time(signal_array, timestamps, fs, n_samples=53):
    """
    Rebuild a per-sample time axis from per-block timestamps and flatten the
    signal.

    Within every block except the last, sample times are linearly spaced
    between the block timestamp and the next one (next timestamp excluded).
    The last block has no following timestamp, so its samples are spaced by
    ``1 / fs`` starting at its own timestamp.

    Timestamps carry only the time of day. A backward jump of more than 12
    hours between consecutive blocks is interpreted as crossing midnight and
    one day is added, so the time axis does not become negative for
    recordings that span 00:00.

    Parameters
    ----------
    signal_array : np.ndarray
        Blocks of samples, one block per row (blocks x n_samples).
    timestamps : np.ndarray
        One ``"HH:MM:SS.xxxxxx"`` string per block.
    fs : int
        Sampling frequency in Hz (used only for the last block).
    n_samples : int, optional
        Number of time points generated per block (default 53).

    Returns
    -------
    flat : np.ndarray
        Signal flattened to 1-D.
    t_full : np.ndarray
        Time in seconds, relative to the first block, for each time point.
        Both outputs are empty when there are no blocks.
    """
    if len(timestamps) == 0 or len(signal_array) == 0:
        return np.array([]), np.array([])

    abs_time = np.empty(len(timestamps), dtype=float)
    for i, ts in enumerate(timestamps):
        dt = datetime.strptime(str(ts), "%H:%M:%S.%f")
        abs_time[i] = dt.hour * 3600 + dt.minute * 60 + dt.second + dt.microsecond / 1e6

    # Midnight roll-over: only large backward jumps count, so small timestamp
    # jitter is not mistaken for a change of day.
    day_wraps = np.cumsum(np.diff(abs_time) < -43200.0)
    abs_time[1:] += 86400.0 * day_wraps
    rel_time = abs_time - abs_time[0]

    time_chunks, sample_chunks = [], []
    for t0, t1, block in zip(rel_time[:-1], rel_time[1:], signal_array[:-1]):
        time_chunks.append(np.linspace(t0, t1, n_samples, endpoint=False))
        sample_chunks.append(np.asarray(block))

    # Last block: uniform spacing of 1 / fs from the last timestamp.
    time_chunks.append(
        np.linspace(rel_time[-1], rel_time[-1] + (n_samples - 1) / fs, n_samples)
    )
    sample_chunks.append(np.asarray(signal_array[-1]))

    return np.concatenate(sample_chunks), np.concatenate(time_chunks)


def auto_invert_if_needed(signal, fs, label="ECG", debug=True):
    """
    Return the signal with flipped sign when its QRS complexes are mostly
    negative, otherwise return it unchanged.

    Steps:
      - clean the signal, then band-pass it at 5-20 Hz (Butterworth, order 3)
      - detect peaks on the absolute value of the band-passed signal, so the
        detection does not depend on polarity
      - if fewer than 40% of the detected peaks are positive in the
        band-passed signal, negate the input

    Parameters
    ----------
    signal : np.ndarray
        1-D ECG signal.
    fs : int
        Sampling frequency in Hz.
    label, debug
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    np.ndarray
        ``signal`` or ``-signal``. The input is returned as-is when no peak
        is detected.
    """
    cleaned = nk.ecg_clean(signal, sampling_rate=fs)
    qrs_band = nk.signal_filter(
        cleaned,
        sampling_rate=fs,
        lowcut=5,
        highcut=20,
        method="butterworth",
        order=3,
    )

    # Peak detection on the rectified signal is polarity independent.
    _, peak_info = nk.ecg_peaks(np.abs(qrs_band), sampling_rate=fs)
    peak_idx = peak_info.get("ECG_R_Peaks", np.array([], dtype=int))
    if len(peak_idx) == 0:
        return signal

    positive_fraction = np.mean(qrs_band[peak_idx] > 0)
    return -signal if positive_fraction < 0.40 else signal


def _hrv_feature_dict(ecg, fs):
    """
    Compute the HRV features shared by ``compute_hrv_metrics`` and
    ``compute_hrv_on_window`` for one ECG segment.

    Returns a dict with the keys MeanNN, SDNN, RMSSD, pNN50, LF, HF, LF_norm,
    HF_norm, LF_HF_ratio, SDSD, CVNN, CVSD, SD1, SD2 and SD1_SD2_ratio, in
    that order.
    """
    # NOTE(review): ecg_process receives an already cleaned signal; check
    # whether it applies its own cleaning step as well (double filtering).
    ecg_clean = nk.ecg_clean(ecg, sampling_rate=fs)
    _, info = nk.ecg_process(ecg_clean, sampling_rate=fs)

    hrv_time = nk.hrv_time(info, sampling_rate=fs)
    hrv_freq = nk.hrv_frequency(info, sampling_rate=fs)

    mean_nn = hrv_time["HRV_MeanNN"].values[0]
    sdnn = hrv_time["HRV_SDNN"].values[0]
    rmssd = hrv_time["HRV_RMSSD"].values[0]
    pnn50 = hrv_time["HRV_pNN50"].values[0]
    lf = hrv_freq["HRV_LF"].values[0]
    hf = hrv_freq["HRV_HF"].values[0]
    lf_hf = hrv_freq["HRV_LFHF"].values[0]

    # NN intervals (ms) derived directly from the R-peak sample indices.
    rpeaks = info.get("ECG_R_Peaks", None)
    if rpeaks is not None and len(rpeaks) > 1:
        nn_intervals = np.diff(rpeaks) * 1000 / fs
    else:
        nn_intervals = np.array([])

    # SDSD: standard deviation of successive NN differences.
    if len(nn_intervals) > 1:
        sdsd = np.std(np.diff(nn_intervals))
    else:
        sdsd = np.nan

    cvnn = sdnn / mean_nn if mean_nn > 0 else np.nan
    # NOTE(review): CVSD is computed here as SDSD / MeanNN; it is commonly
    # defined as RMSSD / MeanNN. Confirm which definition is intended.
    cvsd = sdsd / mean_nn if mean_nn > 0 else np.nan

    # Poincare descriptors: use the library columns when present, otherwise
    # fall back to the closed-form expressions based on RMSSD and SDNN.
    if "HRV_SD1" in hrv_time.columns:
        sd1 = hrv_time["HRV_SD1"].values[0]
    else:
        sd1 = rmssd / np.sqrt(2)
    if "HRV_SD2" in hrv_time.columns:
        sd2 = hrv_time["HRV_SD2"].values[0]
    else:
        sd2 = np.sqrt(max(0, 2 * sdnn**2 - 0.5 * rmssd**2))
    sd1_sd2_ratio = sd1 / sd2 if sd2 > 0 else np.nan

    return {
        "MeanNN": mean_nn,
        "SDNN": sdnn,
        "RMSSD": rmssd,
        "pNN50": pnn50,
        "LF": lf,
        "HF": hf,
        # The small epsilon avoids a division by zero when LF + HF is zero.
        "LF_norm": lf / (lf + hf + 1e-8),
        "HF_norm": hf / (lf + hf + 1e-8),
        "LF_HF_ratio": lf_hf,
        # Non-linear / derived features
        "SDSD": sdsd,
        "CVNN": cvnn,
        "CVSD": cvsd,
        "SD1": sd1,
        "SD2": sd2,
        "SD1_SD2_ratio": sd1_sd2_ratio,
    }


def compute_hrv_metrics(ecg, fs, label="ECG"):
    """
    Compute HRV features (time domain, frequency domain and a few derived /
    non-linear ones) on the whole signal.

    Parameters
    ----------
    ecg : np.ndarray
        1-D ECG signal.
    fs : int
        Sampling frequency in Hz.
    label : str, optional
        Stored under the ``"Derivazione"`` key of the result.

    Returns
    -------
    dict
        ``"Derivazione"`` followed by the HRV features.
    """
    return {"Derivazione": label, **_hrv_feature_dict(ecg, fs)}


def compute_hrv_on_window(ecg, fs, t_full, start_sec, duration_sec, label="ECG"):
    """
    Compute HRV features on the time window
    ``[start_sec, start_sec + duration_sec)``.

    Parameters
    ----------
    ecg : np.ndarray
        1-D ECG signal.
    fs : int
        Sampling frequency in Hz.
    t_full : np.ndarray
        Time in seconds of each sample of ``ecg``.
    start_sec, duration_sec : float
        Start and length of the window, in seconds.
    label : str, optional
        Stored under ``"Derivazione"``; the text after the last ``" - "`` is
        stored under ``"Fase"``.

    Returns
    -------
    dict or None
        ``"Derivazione"``, ``"Fase"`` and the HRV features, or ``None`` when
        the window holds fewer than ``30 * fs`` samples.
    """
    mask = (t_full >= start_sec) & (t_full < start_sec + duration_sec)
    seg = ecg[mask]
    if len(seg) < fs * 30:
        return None

    return {
        "Derivazione": label,
        "Fase": label.split(" - ")[-1],
        **_hrv_feature_dict(seg, fs),
    }


def evaluate_signal_quality(ecg, fs, label="ECG", debug=True):
    """
    Decide whether one ECG lead is reliable (True) or not (False).

    The checks run in this order and the first failure returns False:
      1. enough R-peaks for the signal duration (at least 20 per minute)
      2. mean heart rate strictly between 35 and 210 bpm, and heart-rate
         standard deviation not above 25
      3. at most 30% of RR intervals outside 350-1800 ms
      4. power ratio between the 5-20 Hz band and the rest of at least 0.1
      5. mean absolute R-peak amplitude of at least 0.1 times the standard
         deviation of the cleaned signal

    Parameters
    ----------
    ecg : np.ndarray
        1-D ECG signal.
    fs : int
        Sampling frequency in Hz.
    label : str, optional
        Prefix of the printed messages.
    debug : bool, optional
        When True, print the outcome (failed criterion or summary).

    Returns
    -------
    bool
    """
    def _reject(message):
        # Report the failed criterion (debug mode only) and flag the lead.
        if debug:
            print(message)
        return False

    cleaned = nk.ecg_clean(ecg, sampling_rate=fs)
    _, peak_info = nk.ecg_peaks(cleaned, sampling_rate=fs)
    locs = np.array(peak_info.get("ECG_R_Peaks", []), dtype=int)
    has_intervals = len(locs) > 1

    # 1) Minimum beat count: 20 beats per minute of recording.
    duration_minutes = max(1e-9, len(ecg) / fs / 60.0)
    if len(locs) < duration_minutes * 20:
        return _reject(f"{label}: troppo pochi R-peaks ({len(locs)})")

    # 2) Mean heart rate and its stability.
    hr_mean, hr_std = np.nan, np.nan
    if has_intervals:
        rate = nk.ecg_rate(locs, sampling_rate=fs)
        hr_mean = float(np.nanmean(rate))
        hr_std = float(np.nanstd(rate))

    # Written as a negated range test so that a NaN mean is rejected too.
    if not (35 < hr_mean < 210):
        return _reject(f"{label}: HR medio fuori range ({hr_mean:.1f} bpm)")
    if hr_std > 25:
        return _reject(f"{label}: HR instabile (std={hr_std:.1f})")

    # 3) Fraction of implausible RR intervals.
    bad_rr = 1.0
    if has_intervals:
        rr_ms = np.diff(locs) * 1000.0 / fs
        out_of_range = (rr_ms < 350) | (rr_ms > 1800)
        bad_rr = np.sum(out_of_range) / len(rr_ms)
    if bad_rr > 0.3:
        return _reject(f"{label}: troppi intervalli RR anomali ({bad_rr*100:.1f}%)")

    # 4) Power in the QRS band versus the remaining power.
    qrs_band = nk.signal_filter(
        cleaned,
        sampling_rate=fs,
        lowcut=5,
        highcut=20,
        method="butterworth",
        order=3,
    )
    qrs_power = float(np.mean(qrs_band**2))
    total_power = float(np.mean(cleaned**2))
    noise_power = max(total_power - qrs_power, 1e-8)
    snr = qrs_power / noise_power
    if snr < 0.1:
        return _reject(f"{label}: SNR basso ({snr:.2f})")

    # 5) Mean R-peak amplitude relative to the signal spread.
    peak_amplitudes = cleaned[locs] if len(locs) > 0 else np.array([])
    if len(peak_amplitudes) == 0 or np.mean(np.abs(peak_amplitudes)) < 0.1 * np.std(cleaned):
        return _reject(f"{label}: R-peaks troppo piccoli o indistinti")

    if debug:
        print(f"{label}: affidabile (HR={hr_mean:.1f}±{hr_std:.1f}, RR_ok={100*(1-bad_rr):.1f}%, SNR={snr:.2f})")
    return True


def plot_filtered_window_multideriv(ecg_list, labels, fs, t_full,
                                    start_sec, duration_sec,
                                    title_phase, subject_number, save_dir):
    """
    Plot the cleaned signal and the detected R-peaks of each lead inside the
    window ``[start_sec, start_sec + duration_sec)`` and save the figure as a
    PNG in ``save_dir``.

    Parameters
    ----------
    ecg_list : list of np.ndarray
        One 1-D signal per lead, all aligned with ``t_full``.
    labels : list of str
        One label per lead (subplot title and legend entry).
    fs : int
        Sampling frequency in Hz.
    t_full : np.ndarray
        Time in seconds of each sample.
    start_sec, duration_sec : float
        Start and length of the window, in seconds.
    title_phase : str
        Phase name, used in the titles and (spaces replaced by underscores)
        in the file name.
    subject_number : str
        Subject identifier, used in the titles and in the file name.
    save_dir : str
        Directory where the PNG is written.

    Returns
    -------
    None
        Nothing is plotted or saved when the window contains no samples.
    """
    mask = (t_full >= start_sec) & (t_full < start_sec + duration_sec)
    if not np.any(mask):
        # E.g. the recording is shorter than start_sec: there is nothing to
        # filter or plot, so skip instead of failing inside the filters.
        print(f"Finestra {start_sec}-{start_sec + duration_sec} s vuota: "
              f"grafico '{title_phase}' non generato")
        return
    tseg = t_full[mask]

    # One subplot row per plotted lead (three in the usual case).
    n_rows = max(1, min(len(ecg_list), len(labels)))

    fig = plt.figure(figsize=(15, 8))
    try:
        for i, (ecg, label) in enumerate(zip(ecg_list, labels), start=1):
            seg = ecg[mask]
            ecg_clean = nk.ecg_clean(seg, sampling_rate=fs)
            _, rpeaks = nk.ecg_peaks(ecg_clean, sampling_rate=fs)
            locs = rpeaks.get("ECG_R_Peaks", [])

            plt.subplot(n_rows, 1, i)
            plt.plot(tseg, ecg_clean, label=f"{label}")
            if len(locs) > 0:
                plt.scatter(tseg[locs], ecg_clean[locs], color="red", s=15, label="R-peaks")
            plt.title(f"{label} — {title_phase} (Soggetto {subject_number})")
            plt.ylabel("Ampiezza (a.u.)")
            plt.xlabel("Tempo [s]")
            plt.grid()
            plt.legend()

        plt.tight_layout()
        filename = f"ECG_{title_phase.replace(' ', '_')}_Soggetto_{subject_number}.png"
        plt.savefig(os.path.join(save_dir, filename), dpi=300)
    finally:
        # Always release the figure, even if filtering or saving failed.
        plt.close(fig)


# ======================================================
# ESECUZIONE
# ======================================================

# NOTE(review): everything below runs at import time (there is no
# `if __name__ == "__main__":` guard) and base_dir is a hard-coded absolute
# path on the author's machine.
base_dir = "/Users/annarosascalcione/Desktop/università/Magistrale/Tesi Anna/codici/soggetti"
# The subject number is asked interactively and used to build all file names.
subject_number = input("Inserisci il numero del soggetto: ").strip()
# Per-subject output folder (created if missing); the PNG plots are saved here.
output_dir = os.path.join(base_dir, f"Soggetto_{subject_number}")
os.makedirs(output_dir, exist_ok=True)

# -------------------- REST --------------------
# Convert the rest-phase text log to .npz, then load it back.
txt_path = os.path.join(base_dir, f"ECG_LOG_riposo_{subject_number}.txt")
npz_path = os.path.join(base_dir, f"ECG_LOG_riposo_{subject_number}.npz")
parse_txt_to_npz(txt_path, npz_path)
ecg1_array, ecg2_array, ecg3_array, timestamps, fs = load_ecg_data(npz_path)

# Flatten each lead. The time axis is kept from the first call only: it is
# built from the timestamps, fs and the block size, which are the same for
# the three leads.
flat_ecg1, t_full = reconstruct_time(ecg1_array, timestamps, fs)
flat_ecg2, _ = reconstruct_time(ecg2_array, timestamps, fs)
flat_ecg3, _ = reconstruct_time(ecg3_array, timestamps, fs)
# Flip the leads whose QRS complexes are mostly negative.
flat_ecg1 = auto_invert_if_needed(flat_ecg1, fs)
flat_ecg2 = auto_invert_if_needed(flat_ecg2, fs)
flat_ecg3 = auto_invert_if_needed(flat_ecg3, fs)

# One row of HRV features per lead, computed on the whole rest recording.
df_rest = pd.DataFrame([compute_hrv_metrics(ecg, fs, label)
                        for ecg, label in zip([flat_ecg1, flat_ecg2, flat_ecg3],
                                              ["Derivazione I", "Derivazione II", "Derivazione III"])])

# Save a plot of the 10 s excerpt that starts 300 s into the rest recording.
plot_filtered_window_multideriv(
    [flat_ecg1, flat_ecg2, flat_ecg3],
    ["Derivazione I", "Derivazione II", "Derivazione III"],
    fs, t_full, start_sec=300, duration_sec=10,
    title_phase="Riposo", subject_number=subject_number, save_dir=output_dir
)

# ------------------- EXERCISE BIKE -------------------
# Same steps for the exercise-bike log. The variables used for the rest
# phase (arrays, timestamps, fs, flat signals, t_full) are rebound here.
txt_path = os.path.join(base_dir, f"ECG_LOG_cyclette_{subject_number}.txt")
npz_path = os.path.join(base_dir, f"ECG_LOG_cyclette_{subject_number}.npz")
parse_txt_to_npz(txt_path, npz_path)
ecg1_array, ecg2_array, ecg3_array, timestamps, fs = load_ecg_data(npz_path)

flat_ecg1, t_full = reconstruct_time(ecg1_array, timestamps, fs)
flat_ecg2, _ = reconstruct_time(ecg2_array, timestamps, fs)
flat_ecg3, _ = reconstruct_time(ecg3_array, timestamps, fs)
flat_ecg1 = auto_invert_if_needed(flat_ecg1, fs)
flat_ecg2 = auto_invert_if_needed(flat_ecg2, fs)
flat_ecg3 = auto_invert_if_need
