#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Analisi ECG con tempo reale, HR protocollo, VO₂max e valutazione affidabilità derivazioni
Autore: Annarosa Scalcione
Data: 2025-09-13

NOTE:
- Questo script:
  1) effettua il parsing dei log .txt ECG (3 derivazioni, blocchi di 53 campioni) e salva in .npz
  2) ricostruisce l'asse temporale "reale" campione-per-campione
  3) corregge automaticamente la polarità delle derivazioni, se necessario
  4) mostra grafici grezzi e filtrati con R-peaks
  5) calcola HR su finestre del protocollo e stima il VO₂max (Ekblom-Bak)
  6) valuta la qualità del segnale per derivazione e calcola media affidabile
  7) confronta con VO₂ gold standard se presente nel file Excel
"""

# ==============================================================
# LIBRERIE
# ==============================================================
import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
import os
import neurokit2 as nk
from datetime import datetime


# ==============================================================
# FUNZIONI DI SUPPORTO
# ==============================================================

def parse_txt_to_npz(txt_path, npz_path):
    """
    Parsing del file .txt contenente ECG e salvataggio in formato .npz.

    Formato atteso delle righe:
      HH:MM:SS.xxxxxx [[<53 campioni ch1>] [<53 campioni ch2>] [<53 campioni ch3>]]
    """
    pattern = r"(\d{2}:\d{2}:\d{2}\.\d+)\s+\[\[(.*?)\]\s+\[(.*?)\]\s+\[(.*?)\]\]"
    with open(txt_path, "r") as file:
        text = file.read()
    matches = re.findall(pattern, text, re.DOTALL)

    timestamps, ecg1_list, ecg2_list, ecg3_list = [], [], [], []
    for ts, e1, e2, e3 in matches:
        try:
            ecg1 = np.fromstring(e1, sep=' ')
            ecg2 = np.fromstring(e2, sep=' ')
            ecg3 = np.fromstring(e3, sep=' ')
            if len(ecg1) == len(ecg2) == len(ecg3):
                ecg1_list.append(ecg1)
                ecg2_list.append(ecg2)
                ecg3_list.append(ecg3)
                timestamps.append(ts)
        except:
            print(f"Errore parsing blocco dopo {ts}")

    np.savez(npz_path, 
             timestamps=np.array(timestamps),
             ecg1=np.array(ecg1_list),
             ecg2=np.array(ecg2_list),
             ecg3=np.array(ecg3_list))
    print(f"File salvato: {npz_path}")


def load_ecg_data(filename):
    """
    Carica i dati ECG da un file .npz.

    Returns
    -------
    ecg1, ecg2, ecg3 : np.ndarray (blocchi x 53)
    timestamps       : np.ndarray di stringhe "HH:MM:SS.xxxxxx"
    fs               : int (frequenza di campionamento, 1000 Hz)
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"Il file {filename} non esiste.")
    data = np.load(filename, allow_pickle=True)
    return data['ecg1'], data['ecg2'], data['ecg3'], data['timestamps'], 1000  # fs


def reconstruct_time(timestamps, fs, n_samples=53):
    """
    Ricostruisce il tempo reale campione per campione tramite interpolazione lineare
    tra i timestamp di testa-blocco, assumendo blocchi di n_samples campioni.

    Parametri
    ---------
    timestamps : array di stringhe "HH:MM:SS.xxxxxx"
    fs         : int, frequenza di campionamento (Hz)
    n_samples  : int, campioni per blocco (default 53)

    Ritorna
    -------
    t_full : np.ndarray, tempi in secondi per ogni campione del tracciato completo
    """
    rel_time = []
    for ts in timestamps:
        dt = datetime.strptime(str(ts), "%H:%M:%S.%f")
        sec = dt.hour*3600 + dt.minute*60 + dt.second + dt.microsecond/1e6
        rel_time.append(sec)
    rel_time = np.array(rel_time) - rel_time[0]

    t_full = []
    # Interpola i tempi all'interno di ciascun intervallo tra blocchi consecutivi
    for t0, t1 in zip(rel_time[:-1], rel_time[1:]):
        block_times = np.linspace(t0, t1, n_samples, endpoint=False)
        t_full.extend(block_times)
    # Ultimo blocco: genera tempi uniformi a partire dall'ultimo timestamp
    block_times = np.linspace(rel_time[-1], rel_time[-1]+(n_samples-1)/fs, n_samples)
    t_full.extend(block_times)
    return np.array(t_full)


def flatten_signals(ecg_array):
    """
    Appiattisce un array (blocchi x campioni) in un vettore 1D.
    """
    return ecg_array.flatten()


def auto_invert_if_needed(signal, fs, label="ECG", debug=True):
    """
    Rileva automaticamente se la derivazione è invertita e, in caso, la inverte.

    Strategia:
    1) Pulizia -> band-pass 5–20 Hz (QRS)
    2) Rilevo picchi su |QRS| (indipendente dalla polarità)
    3) Stimo il segno medio dei campioni ai picchi
    4) Se la maggioranza è negativa -> inverto
    """
    ecg_clean = nk.ecg_clean(signal, sampling_rate=fs)
    ecg_qrs = nk.signal_filter(ecg_clean, sampling_rate=fs,
                               lowcut=5, highcut=20, method="butterworth", order=3)

    _, rpeaks = nk.ecg_peaks(np.abs(ecg_qrs), sampling_rate=fs)
    locs = rpeaks.get("ECG_R_Peaks", np.array([], dtype=int))

    if len(locs) == 0:
        if debug:
            print(f"{label}: nessun R-peak trovato per il test di polarità (nessuna inversione applicata).")
        return signal

    vals = ecg_qrs[locs]
    frac_pos = np.mean(vals > 0)
    frac_neg = 1 - frac_pos

    if debug:
        med_pos = np.median(vals[vals > 0]) if np.any(vals > 0) else np.nan
        med_neg = np.median(-vals[vals < 0]) if np.any(vals < 0) else np.nan
        print(f"{label}: R-positivi {frac_pos:.2%} / R-negativi {frac_neg:.2%} "
              f"(|med pos|={med_pos:.3f}, |med neg|={med_neg:.3f})")

    # Soglia: se >60% dei picchi è negativo, invertiamo
    if frac_pos < 0.40:
        if debug:
            print(f"{label}: derivazione invertita automaticamente.")
        return -signal
    else:
        return signal


def plot_all_raw(ecg1, ecg2, ecg3, tt, start, dur):
    """
    Plot dei segnali ECG grezzi (3 derivazioni) nella stessa finestra [start, start+dur].
    """
    mask = (tt >= start) & (tt < start+dur)
    segs = [ecg1[mask], ecg2[mask], ecg3[mask]]
    tseg = tt[mask]
    labels = ["Derivazione I", "Derivazione II", "Derivazione III"]

    plt.figure(figsize=(15, 8))
    for i, (seg, lab) in enumerate(zip(segs, labels), start=1):
        plt.subplot(3, 1, i)
        plt.plot(tseg, seg, label="ECG grezzo")
        plt.title(f"{lab} — grezzo [{start}-{start+dur}s]")
        plt.xlabel("Tempo reale [s]")
        plt.ylabel("Ampiezza [a.u.]")
        plt.grid()
    plt.tight_layout()
    plt.show()


def plot_all_filtered(ecg1, ecg2, ecg3, tt, start, dur, fs):
    """
    Plot dei segnali ECG filtrati con marcatura dei R-peaks (3 derivazioni).
    """
    mask = (tt >= start) & (tt < start+dur)
    segs = [ecg1[mask], ecg2[mask], ecg3[mask]]
    tseg = tt[mask]
    labels = ["Derivazione I", "Derivazione II", "Derivazione III"]

    plt.figure(figsize=(15, 8))
    for i, (seg, lab) in enumerate(zip(segs, labels), start=1):
        ecg_clean = nk.ecg_clean(seg, sampling_rate=fs)
        _, rpeaks = nk.ecg_peaks(ecg_clean, sampling_rate=fs)
        rlocs = rpeaks["ECG_R_Peaks"]

        plt.subplot(3, 1, i)
        plt.plot(tseg, ecg_clean, label="ECG filtrato")
        plt.plot(tseg[rlocs], ecg_clean[rlocs], "ro", label="R-peaks")
        plt.title(f"{lab} — filtrato [{start}-{start+dur}s]")
        plt.xlabel("Tempo reale [s]")
        plt.ylabel("Ampiezza [a.u.]")
        plt.legend()
        plt.grid()
    plt.tight_layout()
    plt.show()


def compute_hr_protocol(sig, tt, fs, centers, label):
    """
    Calcola la HR media nelle finestre centrate sui tempi di protocollo.

    Parametri
    ---------
    sig     : array 1D, segnale ECG
    tt      : array 1D, tempi (s)
    fs      : int, frequenza di campionamento
    centers : lista di centri finestra (s); finestra fissa ±7.5 s
    label   : string, etichetta derivazione per stampa
    """
    results = {}
    for c in centers:
        mask = (tt >= c-7.5) & (tt <= c+7.5)
        seg = sig[mask]
        if len(seg) > 0:
            ecg_clean = nk.ecg_clean(seg, sampling_rate=fs)
            _, rpeaks = nk.ecg_peaks(ecg_clean, sampling_rate=fs)
            rlocs = rpeaks["ECG_R_Peaks"]
            hr = nk.ecg_rate(rlocs, sampling_rate=fs) if len(rlocs) > 1 else [np.nan]
            results[c] = np.nanmean(hr)
        else:
            results[c] = np.nan
    print(f"\n{label}")
    for c, hr in results.items():
        print(f"  HR @ {c:.0f}s = {hr:.1f} bpm")
    return results


def calcola_vo2max(sex, age, delta_hr, delta_po, hr_standard):
    """
    Stima VO₂max con formula Ekblom-Bak.
    sex: 'M' (maschio) o 'F' (femmina)
    """
    if sex == "M":
        return np.exp(
            (2.04900 - 0.00858 * age)
            - (0.90742 * delta_hr / delta_po)
            + (0.00178 * delta_po)
            - (0.00290 * hr_standard)
        )
    elif sex == "F":
        return np.exp(
            (1.84390 - 0.00673 * age)
            - (0.62578 * delta_hr / delta_po)
            + (0.00175 * delta_po)
            - (0.00471 * hr_standard)
        )
    else:
        raise ValueError("Sesso non valido. Usa 'M' o 'F'.")


def evaluate_signal_quality(ecg, fs, label="ECG", debug=True):
    """
    Valuta la qualità di una derivazione ECG (affidabile / non affidabile) con criteri:
      - numero sufficiente di R-peaks
      - HR medio plausibile e stabilità (std)
      - intervalli RR fisiologici
      - SNR sufficiente (QRS vs resto)
      - ampiezza media R-peaks
    Restituisce True/False.
    """
    ecg_clean = nk.ecg_clean(ecg, sampling_rate=fs)
    _, rpeaks = nk.ecg_peaks(ecg_clean, sampling_rate=fs)
    locs = rpeaks.get("ECG_R_Peaks", [])

    # 1) Numero minimo di battiti (>= 20 bpm equivalenti sulla durata)
    durata_min = len(ecg) / fs / 60
    if len(locs) < durata_min * 20:
        if debug: print(f"{label}: troppo pochi R-peaks ({len(locs)})")
        return False

    # 2) HR medio e stabilità
    hr = nk.ecg_rate(locs, sampling_rate=fs)
    hr_mean, hr_std = np.nanmean(hr), np.nanstd(hr)
    if not (35 < hr_mean < 210):
        if debug: print(f"{label}: HR medio fuori range ({hr_mean:.1f} bpm)")
        return False
    if hr_std > 25:
        if debug: print(f"{label}: HR instabile (std={hr_std:.1f})")
        return False

    # 3) Intervalli RR plausibili
    rr = np.diff(locs) * 1000 / fs  # in ms
    bad_rr = np.sum((rr < 350) | (rr > 1800)) / len(rr) if len(rr) > 0 else 1
    if bad_rr > 0.3:
        if debug: print(f"{label}: troppi intervalli RR anomali ({bad_rr*100:.1f}%)")
        return False

    # 4) SNR stimato nel band-pass QRS
    ecg_qrs = nk.signal_filter(ecg_clean, sampling_rate=fs, lowcut=5, highcut=20)
    power_qrs = np.mean(ecg_qrs**2)
    power_tot = np.mean(ecg_clean**2)
    snr = power_qrs / (power_tot - power_qrs + 1e-8)
    if snr < 0.1:
        if debug: print(f"{label}: SNR basso ({snr:.2f})")
        return False

    # 5) Ampiezza media R-peaks
    r_values = ecg_clean[locs] if len(locs) > 0 else []
    if len(r_values) == 0 or np.mean(np.abs(r_values)) < 0.1 * np.std(ecg_clean):
        if debug: print(f"{label}: R-peaks troppo piccoli o indistinti")
        return False

    if debug:
        print(f"{label}: affidabile (HR={hr_mean:.1f}±{hr_std:.1f}, RR_ok={100*(1-bad_rr):.1f}%, SNR={snr:.2f})")
    return True


def load_subject_data(excel_path, subject_id):
    """
    Carica i dati anagrafici e di test del soggetto dal file Excel.

    Richiede colonne: soggetto, sesso, eta, peso, po_high, (opzionale) vo2_gold.
    """
    df = pd.read_excel(excel_path)
    row = df.loc[df["soggetto"] == int(subject_id)]
    if row.empty:
        raise ValueError(f"Nessun soggetto {subject_id} trovato in {excel_path}")
    
    sex = str(row["sesso"].values[0]).upper()
    age = int(row["eta"].values[0])
    weight = float(row["peso"].values[0])
    po_high = float(row["po_high"].values[0])
    # ΔPO secondo mappatura ergometro
    delta_po = 1.079 * po_high - 31.618
    vo2_gold = float(row["vo2_gold"].values[0]) if "vo2_gold" in df.columns else None
    
    return sex, age, weight, po_high, delta_po, vo2_gold


# ==============================================================
# ESECUZIONE
# ==============================================================

# --- Input soggetto con controllo esistenza file
base_dir = "/Users/annarosascalcione/Desktop/università/Magistrale/Tesi Anna/codici/soggetti"

while True:
    subject_id = input("Inserisci il numero del soggetto: ").strip()
    txt_path = os.path.join(base_dir, f"ECG_LOG_cyclette_{subject_id}.txt")
    npz_path = os.path.join(base_dir, f"ECG_LOG_cyclette_{subject_id}.npz")

    if os.path.exists(txt_path):
        print(f"\nAnalisi soggetto {subject_id}")
        print(f"  TXT: {txt_path}")
        print(f"  NPZ: {npz_path}\n")
        break
    else:
        print(f"File {txt_path} non trovato. Riprova.\n")

# --- Step 1: parsing da txt a npz
parse_txt_to_npz(txt_path, npz_path)

# --- Step 2: caricamento npz
ecg1_array, ecg2_array, ecg3_array, timestamps, fs = load_ecg_data(npz_path)

# --- Step 3: ricostruzione tempo reale
t_full = reconstruct_time(timestamps, fs)
ecg1 = flatten_signals(ecg1_array)
ecg2 = flatten_signals(ecg2_array)
ecg3 = flatten_signals(ecg3_array)

# --- Step 3b: correzione automatica polarità
ecg1 = auto_invert_if_needed(ecg1, fs, "Derivazione I")
ecg2 = auto_invert_if_needed(ecg2, fs, "Derivazione II")
ecg3 = auto_invert_if_needed(ecg3, fs, "Derivazione III")

# --- Step 4: plot finestra di esempio
start_sec, dur_sec = 300, 10
plot_all_raw(ecg1, ecg2, ecg3, t_full, start_sec, dur_sec)
plot_all_filtered(ecg1, ecg2, ecg3, t_full, start_sec, dur_sec, fs)

# --- Step 5: HR protocollo (centri finestra a ±7.5s)
centers_4min = [195, 210, 225, 240]
centers_8min = [435, 450, 465, 480]
hr1 = compute_hr_protocol(ecg1, t_full, fs, centers_4min+centers_8min, "Derivazione I")
hr2 = compute_hr_protocol(ecg2, t_full, fs, centers_4min+centers_8min, "Derivazione II")
hr3 = compute_hr_protocol(ecg3, t_full, fs, centers_4min+centers_8min, "Derivazione III")

# --- Step 6: VO₂max per derivazione e media solo affidabili
excel_path = "/Users/annarosascalcione/Desktop/università/Magistrale/Tesi Anna/codici/soggetti/dati_soggetti.xlsx"
sex, age, weight, po_high, delta_po, vo2_gold = load_subject_data(excel_path, subject_id)

print(f"\nDati soggetto {subject_id}:")
print(f"  Sesso={sex}, Età={age}, Peso={weight} kg, PO_high={po_high} W, ΔPO={delta_po:.2f}")

vo2_results = {}
reliable_vo2 = {}

print("\n=== Risultati VO₂max per derivazione ===")
for ecg, hr_dict, label in zip(
        [ecg1, ecg2, ecg3],
        [hr1, hr2, hr3],
        ["Derivazione I", "Derivazione II", "Derivazione III"]):

    if all(~np.isnan(list(hr_dict.values()))):
        hr_standard = np.nanmean([hr_dict[c] for c in centers_4min])
        delta_hr = np.nanmean([hr_dict[c] for c in centers_8min]) - hr_standard
        vo2max = calcola_vo2max(sex, age, delta_hr, delta_po, hr_standard)
        vo2max_mlkgmin = vo2max * 1000 / weight
        vo2_results[label] = vo2max_mlkgmin

        # Valutazione affidabilità della derivazione
        reliable = evaluate_signal_quality(ecg, fs, label)
        if reliable:
            reliable_vo2[label] = vo2max_mlkgmin
            print(f"{label} - VO₂max stimato = {vo2max_mlkgmin:.2f} ml/kg/min AFFIDABILE")
        else:
            print(f"{label} - VO₂max stimato = {vo2max_mlkgmin:.1f} ml/kg/min ATTENZIONE: qualità segnale potenzialmente NON affidabile")
    else:
        vo2_results[label] = np.nan
        print(f"{label} - HR insufficienti per calcolare il VO₂max.")

# --- Step 7: media con controllo affidabilità
if len(reliable_vo2) == len(vo2_results):  
    # Tutte affidabili
    vo2_mean = np.mean(list(reliable_vo2.values()))
    print(f"\nTutte le derivazioni affidabili.")
    print(f"VO₂max medio = {vo2_mean:.2f} ml/kg/min")
else:
    # Alcune non affidabili
    print("\nAlcune derivazioni non sono affidabili.")
    print("   Valori stimati per ciascuna derivazione:")
    for i, (label, val) in enumerate(vo2_results.items(), start=1):
        print(f"   {i}. {label}: {val:.2f} ml/kg/min {'(affidabile)' if label in reliable_vo2 else '(non affidabile)'}")

    # Scelta manuale della derivazione (facoltativa)
    choice = input("Vuoi scegliere una derivazione specifica da usare? (1/2/3 o Invio per media affidabili): ").strip()
    if choice in ["1", "2", "3"]:
        chosen_label = list(vo2_results.keys())[int(choice)-1]
        vo2_mean = vo2_results[chosen_label]
        print(f"Usata solo {chosen_label}: VO₂max = {vo2_mean:.2f} ml/kg/min")
    else:
        if reliable_vo2:
            vo2_mean = np.mean(list(reliable_vo2.values()))
            print(f"VO₂max medio (solo derivazioni affidabili) = {vo2_mean:.2f} ml/kg/min")
        else:
            vo2_mean = np.nan
            print("Nessuna derivazione affidabile disponibile.")

# --- Step 8: confronto con gold standard (se presente)
if vo2_gold is not None and not np.isnan(vo2_mean):
    err = abs(vo2_mean - vo2_gold) / vo2_gold * 100
    print(f"\nVO₂max gold standard = {vo2_gold:.2f} ml/kg/min")
    print(f"Errore percentuale = {err:.2f}%")
elif vo2_gold is None:
    print("\nNessun valore gold standard disponibile in Excel.")
else:
    print("\nNessun valore VO₂ stimato disponibile per calcolare l'errore.")
