from __future__ import annotations

from typing import Optional, Tuple

import numpy as np
from scipy.signal import butter, filtfilt, sosfiltfilt


def _poly_detrend(x: np.ndarray, order: int = 1) -> np.ndarray:
    """Remove a polynomial trend (default: linear) along time for each column."""
    if order <= 0:
        return x - np.mean(x, axis=0, keepdims=True)

    n, _ = x.shape
    t = np.arange(n, dtype=float)
    T = np.vander(t, N=order + 1, increasing=True)
    coeffs, *_ = np.linalg.lstsq(T, x, rcond=None)
    trend = T @ coeffs
    return x - trend


def butter_band_filter(
    data: np.ndarray,
    fs: float,
    hp: Optional[float],
    lp: Optional[float],
    order: int = 4,
) -> np.ndarray:
    """Zero-phase Butterworth high-, low-, or band-pass filter along axis 0.

    Parameters
    ----------
    data : array_like
        Signal with time along axis 0.
    fs : float
        Sampling rate, in the same units as ``hp`` and ``lp``.
    hp : float or None
        High-pass cutoff. ``None`` disables the high-pass edge.
    lp : float or None
        Low-pass cutoff. ``None`` disables the low-pass edge.
    order : int
        Butterworth order passed to ``scipy.signal.butter`` (a band-pass
        design has twice this many poles).

    Returns
    -------
    np.ndarray
        Filtered signal, or ``data`` unchanged when both cutoffs are ``None``.

    Raises
    ------
    ValueError
        If ``fs`` is not positive, a cutoff is not strictly between 0 and the
        Nyquist frequency ``fs / 2``, or ``hp >= lp``. ``sosfiltfilt`` also
        raises ``ValueError`` when the signal is too short for its padding.
    """
    if hp is None and lp is None:
        return data

    if not fs > 0:
        raise ValueError(f"fs must be positive, got {fs!r}")
    nyq = 0.5 * fs
    for name, cutoff in (("hp", hp), ("lp", lp)):
        if cutoff is not None and not 0.0 < cutoff < nyq:
            raise ValueError(
                f"{name}={cutoff!r} must lie strictly between 0 and the "
                f"Nyquist frequency {nyq!r}"
            )

    if hp is not None and lp is not None:
        if hp >= lp:
            raise ValueError(f"hp ({hp!r}) must be smaller than lp ({lp!r})")
        wn = [hp / nyq, lp / nyq]
        btype = "bandpass"
    elif hp is not None:
        wn = hp / nyq
        btype = "highpass"
    else:
        wn = lp / nyq
        btype = "lowpass"

    # Second-order sections instead of (b, a) polynomials: high-order or
    # narrow/low-cutoff designs lose precision in transfer-function form,
    # which scipy explicitly warns about for general filtering.
    sos = butter(order, wn, btype=btype, output="sos")
    return sosfiltfilt(sos, data, axis=0)


def clean_signal_matrix(
    signal: np.ndarray,
    fc: float,
    ss: int = 2,
    hp: Optional[float] = 0.5,
    lp: Optional[float] = 15.0,
    filter_order: int = 4,
    remove_mean: bool = True,
    detrend_order: Optional[int] = 1,
) -> Tuple[np.ndarray, float]:
    """SASI-style cleaning pipeline: de-mean, detrend, filter, subsample.

    Parameters
    ----------
    signal : array_like
        Signal with time along axis 0. A 1-D input is treated as a single
        channel and returned as 1-D.
    fc : float
        Sampling rate of ``signal`` (used as ``fs`` for the filter).
    ss : int
        Subsampling step: every ``ss``-th sample is kept. Converted with
        ``int()`` and clamped to at least 1.
    hp, lp : float or None
        High-/low-pass cutoffs forwarded to ``butter_band_filter``; ``None``
        disables that edge.
    filter_order : int
        Butterworth order forwarded to ``butter_band_filter``.
    remove_mean : bool
        Subtract the per-channel mean before detrending.
    detrend_order : int or None
        Polynomial degree for ``_poly_detrend``; ``None`` or a negative value
        skips detrending (0 removes the mean again, which is harmless).

    Returns
    -------
    (cleaned_signal, fs_new)
        The processed signal after subsampling, and the new sampling rate
        ``fc / ss``.

    Notes
    -----
    Subsampling is plain slicing with no additional anti-alias filter, so any
    content above ``fs_new / 2`` (e.g. when ``lp`` is ``None`` or larger than
    ``fc / (2 * ss)``) aliases into the output.
    """
    x = np.asarray(signal, dtype=float)

    # Promote a single channel to a column so every stage sees time on axis 0
    # with at least one channel axis; restored to 1-D before returning.
    single_channel = x.ndim == 1
    if single_channel:
        x = x[:, np.newaxis]

    if remove_mean:
        x = x - np.mean(x, axis=0, keepdims=True)

    if detrend_order is not None and detrend_order >= 0:
        x = _poly_detrend(x, order=detrend_order)

    x = butter_band_filter(x, fs=fc, hp=hp, lp=lp, order=filter_order)

    ss = max(1, int(ss))
    y = x[::ss]  # subsample along time only; works for any trailing axes
    fs_new = fc / ss

    if single_channel:
        y = y[:, 0]

    return y, fs_new
