"""
Main module for streaming time tag acquisition, real-time delay processing, 
dynamic histogram generation, and final autocorrelation analysis.
"""

import time

import matplotlib.pyplot as plt
import numpy as np
import TimeTagger
from matplotlib.container import BarContainer
from scipy.signal import find_peaks, correlate

# --- 1. Time Tag Stream Acquisition Function ---

def measure_timetag_stream(duration_ms: int, channels_to_measure: list[int]) -> tuple[np.ndarray, np.ndarray]:
    """
    Measure time tags on the given channels for a fixed duration.

    Opens a Time Tagger connection, records events with ``TimeTagStream``
    for ``duration_ms`` milliseconds, and always frees the connection
    afterwards.

    Args:
        duration_ms (int): The duration of the measurement in milliseconds.
        channels_to_measure (list[int]): A list of channel numbers to measure.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(channels, timestamps)`` where
        ``channels[i]`` is the channel number of event ``i`` and
        ``timestamps[i]`` its timestamp in picoseconds. Both arrays are
        empty if no events were recorded, or if an error occurred (the
        error is printed, not raised).
    """
    tagger = None  # Bound before ``try`` so ``finally`` can always test it.
    try:
        tagger = TimeTagger.createTimeTagger()

        # Optional: Enable internal test signals on specified channels
        # tagger.setTestSignal(channels_to_measure, True)

        # Maximum number of events the stream buffers between getData() calls.
        event_buffer_size = 1000000

        stream = TimeTagger.TimeTagStream(tagger=tagger,
                                          n_max_events=event_buffer_size,
                                          channels=channels_to_measure)

        # startFor expects picoseconds: 1 ms = 1E9 ps.
        stream.startFor(int(duration_ms * 1E9))

        # Keep each chunk as an array and concatenate once at the end;
        # extending Python lists element by element is slow and boxes
        # every event into a Python object.
        channel_chunks: list[np.ndarray] = []
        timestamp_chunks: list[np.ndarray] = []

        def _drain() -> None:
            """Move the events currently buffered in ``stream`` into the chunk lists."""
            data = stream.getData()
            if data.size > 0:
                channel_chunks.append(np.array(data.getChannels()))
                timestamp_chunks.append(np.array(data.getTimestamps()))

        while stream.isRunning():
            _drain()
        # Events that arrived between the last poll and the end of the
        # measurement are still in the buffer; fetch them so they are not lost.
        _drain()

        if not timestamp_chunks:
            return np.array([]), np.array([])
        return np.concatenate(channel_chunks), np.concatenate(timestamp_chunks)

    except Exception as e:
        # Deliberate best-effort: report and return empty arrays so the
        # caller's acquisition loop can continue with the next interval.
        print(f"An error occurred during TimeTagStream measurement: {e}")
        return np.array([]), np.array([])

    finally:
        # Ensure the TimeTagger connection is freed
        if tagger is not None:
            TimeTagger.freeTimeTagger(tagger)


# --- 2. Delay Processing and Histogram Functions ---

def process_delays_per_trigger_binned_clicks_simple(timestamps_us: np.ndarray, channel: np.ndarray, window_us: float = 20, hist_bins: int = 50) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
    """
    Calculates the delays of binned clicks relative to each trigger (unbinned) 
    within a specified time window.
    Returns all relative delays, histogram counts, and bin edges.

    Args:
        timestamps_us (np.ndarray): Array of timestamps in microseconds ($\mu$s).
        channel (np.ndarray): Array of channels corresponding to the timestamps.
        window_us (float): Time window for delay calculation in microseconds ($\mu$s).
        hist_bins (int): Number of bins for the histogram.

    Returns:
        np.ndarray: Array containing all calculated delays (in $\mu$s).
        np.ndarray: Final histogram bin values (counts).
        np.ndarray: Final histogram bin edges. Returns None if data is insufficient.
    """
    # Assuming Channel 1 is the reference/trigger (e.g., Sync or start pulse)
    mask_ref = (channel == 1)
    # Assuming Channel 2 is the click channel (e.g., Detection event)
    mask_cl = (channel == 2)

    trig_all_us = timestamps_us[mask_ref]
    click_all_us = timestamps_us[mask_cl]

    if trig_all_us.size == 0 or click_all_us.size == 0:
        return np.array([]), np.array([]), None

    # --- CLICK BINNING (Resolution: 0.0001 microseconds = 100 ps) ---
    bin_size_us = 0.0001
    click_binned_us = np.round(click_all_us / bin_size_us) * bin_size_us

    all_delays_us = np.array([])

    for T_us in trig_all_us:
        # Define the search window for clicks relative to the current trigger T_us
        lower_bound_click_us = T_us
        upper_bound_click_us = T_us + window_us
        
        # Filter clicks within the window
        relevant_clicks_us = click_binned_us[(click_binned_us >= lower_bound_click_us) & (click_binned_us < upper_bound_click_us)]
        
        # Calculate the delay (click_time - trigger_time)
        delays_us = relevant_clicks_us - T_us
        all_delays_us = np.concatenate((all_delays_us, delays_us))

    # Calculate the histogram for the current chunk's delays
    histogram_counts, hist_bin_edges = np.histogram(all_delays_us, bins=hist_bins, range=(0, window_us))

    return all_delays_us, histogram_counts, hist_bin_edges


def initialize_histogram(bin_edges: np.ndarray, title: str = "Click Delays Relative to Trigger", xlabel: str = r"Delay ($\mu$s)", ylabel: str = "Counts") -> tuple[plt.Figure, plt.Axes, BarContainer]:
    """
    Create a new figure with an all-zero bar histogram ready for live updates.

    One bar is drawn per bin, left-aligned at ``bin_edges[:-1]`` with width
    ``np.diff(bin_edges)``. Interactive plotting (``plt.ion()``) is switched on.

    Args:
        bin_edges (np.ndarray): The edges of the histogram bins (length = bins + 1).
        title (str): Title of the plot.
        xlabel (str): Label for the x-axis (mathtext allowed).
        ylabel (str): Label for the y-axis.

    Returns:
        matplotlib.figure.Figure: Figure object.
        matplotlib.axes.Axes: Axes object.
        matplotlib.container.BarContainer: The bars; iterate it to reach the
            individual ``Rectangle`` patches.
    """
    fig, ax = plt.subplots()
    # Start every bar at height zero; update_histogram sets the real counts.
    bars = ax.bar(bin_edges[:-1], np.zeros_like(bin_edges[:-1]), width=np.diff(bin_edges), align='edge')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    plt.ion()  # Turn on interactive plotting so later updates don't block.
    return fig, ax, bars


def update_histogram(ax: plt.Axes, bars: BarContainer, new_counts: np.ndarray):
    """
    Set the bar heights to new (cumulative) counts and refresh the plot.

    Bars and counts are paired positionally; if their lengths differ, the
    extra elements of the longer one are ignored.

    Args:
        ax (matplotlib.axes.Axes): Axes that owns the bars.
        bars (matplotlib.container.BarContainer): Bars created by ``ax.bar``.
        new_counts (np.ndarray): New height for each bar.
    """
    for bar, count in zip(bars, new_counts):
        bar.set_height(count)
    ax.relim()           # Recompute data limits from the resized bars.
    ax.autoscale_view()  # Apply them on axes that still autoscale.
    # Redraw the figure that owns ``ax``; plt.draw() would redraw whichever
    # figure happens to be current, which need not be this one.
    ax.figure.canvas.draw_idle()
    plt.pause(0.001)  # Let the GUI event loop process the redraw.


# --- 3. Autocorrelation Functions ---

def calculate_histogram_autocorrelation(hist_counts: np.ndarray) -> np.ndarray:
    """
    Compute the normalized autocorrelation of histogram counts for lags >= 0.

    The counts are mean-centered (and scaled by their standard deviation),
    correlated with themselves, and the non-negative-lag half is divided by
    its lag-0 value so the result starts at 1.0.

    Args:
        hist_counts (np.ndarray): Histogram counts (any 1-D array-like).

    Returns:
        np.ndarray: Autocorrelation values for lags ``0 .. len(hist_counts) - 1``.
            Returns ``[1.0]`` for inputs with at most one element. A constant
            histogram has no variation to correlate; instead of the NaN a 0/0
            division would give, it returns 1.0 at lag 0 and 0.0 elsewhere.
    """
    hist_counts = np.asarray(hist_counts, dtype=float)
    n = len(hist_counts)
    if n <= 1:
        return np.array([1.0])

    centered = hist_counts - np.mean(hist_counts)
    std_counts = np.std(hist_counts)
    if std_counts == 0:
        # All-zero centered data would make autocorr[lag 0] == 0 -> 0/0 = NaN.
        flat = np.zeros(n)
        flat[0] = 1.0
        return flat

    hist_normalized = centered / std_counts

    autocorr = correlate(hist_normalized, hist_normalized, mode='full', method='auto')

    # 'full' output has length 2n - 1 with lag 0 at index n - 1; keep lags >= 0.
    center_index = n - 1
    return autocorr[center_index:] / autocorr[center_index]

def plot_histogram_autocorrelation(autocorr: np.ndarray, bin_width_us: float, title: str = "Autocorrelation of Delay Histogram",
                                   xlabel: str = r"Lag ($\mu$s)", ylabel: str = "Autocorrelation", max_lag_us: float | None = None) -> float | None:
    """
    Plot an autocorrelation curve in a new figure and mark its first peak at lag > 0.

    Peaks are searched with ``scipy.signal.find_peaks`` (prominence >= 0.1),
    skipping lag 0. The first one found is drawn as a red dot and labeled.

    Args:
        autocorr (np.ndarray): Autocorrelation values, index ``k`` = lag ``k`` bins.
        bin_width_us (float): Width of one histogram bin in µs (lag step).
        title (str): Title of the plot.
        xlabel (str): Label for the x-axis (mathtext allowed).
        ylabel (str): Label for the y-axis.
        max_lag_us (float, optional): Upper x-limit in µs; no limit if None.

    Returns:
        float | None: Lag in µs of the first detected peak, or None if no
            peak was found.
    """
    lags_us = np.arange(len(autocorr)) * bin_width_us
    plt.figure()
    plt.plot(lags_us, autocorr)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.grid(True)

    delay_us = None
    # Search from index 1 so the trivial lag-0 maximum is never reported.
    peaks, _ = find_peaks(autocorr[1:], prominence=0.1)

    if peaks.size > 0:
        first_peak_index = peaks[0] + 1  # Shift back to full-array indexing.
        delay_us = float(lags_us[first_peak_index])
        plt.plot(delay_us, autocorr[first_peak_index], 'ro', label=rf'Peak at Lag $\approx$ {delay_us:.2f} $\mu$s')
        plt.legend()

    if max_lag_us is not None:
        plt.xlim(0, max_lag_us)

    plt.show()
    return delay_us


# --- 4. Main Execution Function ---

def main(total_duration_ms: int = 1000, measurement_interval_ms: int = 100, channels: list[int] | None = None, window_us: float = 20, hist_bins: int = 2000,
         display_max_us: float = 6):
    """
    Acquire data in intervals, build a live cumulative delay histogram, then
    plot the autocorrelation of the final histogram.

    Each interval opens a fresh TimeTagStream measurement, converts its
    timestamps from ps to µs, histograms the trigger-to-click delays, and
    adds them to the running total shown on screen.

    Args:
        total_duration_ms (int): Total acquisition time in milliseconds.
        measurement_interval_ms (int): Length of each acquisition interval in ms;
            ``total_duration_ms // measurement_interval_ms`` intervals are run.
        channels (list[int] | None): Channels to record; defaults to ``[1, 2]``.
        window_us (float): Delay window after each trigger, in µs.
        hist_bins (int): Number of histogram bins.
        display_max_us (float): Upper x-limit (µs) for both the histogram and
            the autocorrelation plot.

    Raises:
        ValueError: If ``measurement_interval_ms`` is not positive.
    """
    if channels is None:
        channels = [1, 2]  # Avoid a shared mutable default argument.
    if measurement_interval_ms <= 0:
        raise ValueError("measurement_interval_ms must be positive")

    num_iterations = total_duration_ms // measurement_interval_ms

    fig_hist, ax_hist, bars_hist = None, None, None
    bin_edges_global = None
    cumulative_histogram_counts = None  # None until the first valid interval.

    print(f"Starting acquisition for {total_duration_ms} ms ({num_iterations} intervals of {measurement_interval_ms} ms)...")

    for i in range(num_iterations):
        print(f"Acquisition interval {i+1}/{num_iterations}...")

        channels_data, timestamps_ps = measure_timetag_stream(measurement_interval_ms, channels)

        if timestamps_ps.size == 0:
            print("No events detected in this iteration.")
            continue

        timestamps_us = timestamps_ps / 1e6  # ps -> µs

        _, histogram_counts_this_chunk, bin_edges = process_delays_per_trigger_binned_clicks_simple(
            timestamps_us, channels_data, window_us, hist_bins
        )

        if histogram_counts_this_chunk is None or bin_edges is None:
            print("No valid histogram data generated in this iteration.")
            continue

        if cumulative_histogram_counts is None:
            # First valid interval: create the plot and seed the running total.
            title = rf"Cumulative Click Delays Relative to Trigger (Window: {window_us} $\mu$s)"
            fig_hist, ax_hist, bars_hist = initialize_histogram(bin_edges, title=title)
            bin_edges_global = bin_edges
            cumulative_histogram_counts = histogram_counts_this_chunk.copy()

            update_histogram(ax_hist, bars_hist, cumulative_histogram_counts)
            ax_hist.set_xlim(0, display_max_us)
            plt.show(block=False)
        elif len(histogram_counts_this_chunk) == len(bin_edges_global) - 1:
            cumulative_histogram_counts += histogram_counts_this_chunk
            update_histogram(ax_hist, bars_hist, cumulative_histogram_counts)
        else:
            print("Error: Length of new counts does not match histogram bin size.")

    # --- Final Autocorrelation Calculation and Plotting ---
    if cumulative_histogram_counts is not None and bin_edges_global is not None:
        autocorr_result = calculate_histogram_autocorrelation(cumulative_histogram_counts)
        bin_width_us = np.diff(bin_edges_global)[0]  # Bins are uniform (np.histogram with int bins).

        plot_histogram_autocorrelation(autocorr_result, bin_width_us,
                                       title="Autocorrelation of Total Delay Histogram",
                                       xlabel=r"Lag ($\mu$s)", ylabel="Normalized Autocorrelation",
                                       max_lag_us=display_max_us)
    else:
        print("No cumulative histogram data available to calculate autocorrelation.")

    if fig_hist is not None:
        plt.show(block=True)  # Keep the final plots open

if __name__ == "__main__":
    # Script entry point: run the full acquisition/analysis pipeline with main()'s default arguments.
    main()