from __future__ import annotations

import json
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
from scipy.signal import csd, find_peaks


def mac(phi_a: np.ndarray, phi_b: np.ndarray) -> float:
    """Return the Modal Assurance Criterion (MAC) between two mode-shape vectors.

    MAC = |a^H b|^2 / ((a^H a) * (b^H b)), computed on the flattened inputs,
    which may be real or complex. The value is nominally in [0, 1]; when the
    energy product is not positive (zero vectors, or NaN) the result is 0.0.
    Vectors of different length raise ``ValueError`` (from ``np.vdot``).
    """
    vec_a = np.ravel(np.asarray(phi_a))
    vec_b = np.ravel(np.asarray(phi_b))
    # np.vdot conjugates its first argument, giving the Hermitian inner product.
    cross_sq = np.abs(np.vdot(vec_a, vec_b)) ** 2
    energy = np.vdot(vec_a, vec_a).real * np.vdot(vec_b, vec_b).real
    if not energy > 0:
        return 0.0
    return float(cross_sq / energy)


def _normalize_shape(phi: np.ndarray) -> np.ndarray:
    phi = np.asarray(phi).reshape(-1)
    # Fix phase/sign for consistency: make max-abs component positive real
    idx = int(np.argmax(np.abs(phi)))
    if phi[idx] != 0:
        phi = phi * np.exp(-1j * np.angle(phi[idx]))
    # If still complex, take real part (prototype simplification)
    phi_r = np.real(phi)
    n = np.linalg.norm(phi_r)
    return (phi_r / n) if n > 0 else phi_r


@dataclass
class FDDEstimate:
    """A single modal estimate taken from a peak of an FDD spectrum.

    In this module instances are built by ``pick_fdd_modes``. The field order
    below is also the positional constructor order.

    NOTE(review): the dataclass-generated ``__eq__`` compares the ``shape``
    arrays, which appears to raise for multi-element arrays (ambiguous truth
    value); compare fields explicitly if equality is ever needed.
    """

    # Peak frequency in Hz.
    frequency_hz: float
    # Mode shape vector; pick_fdd_modes stores a real, unit-norm vector here.
    shape: np.ndarray
    # Peak prominence score; pick_fdd_modes uses peak height / band median.
    quality: float
    # Rough damping ratio, or None when no estimate could be made.
    damping_ratio: Optional[float] = None


def fdd_spectrum(
    X: np.ndarray,
    fs: float,
    nperseg: int,
    noverlap: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute the Frequency Domain Decomposition (FDD) spectrum.

    For every channel pair, the one-sided cross-spectral density is estimated
    with ``scipy.signal.csd`` (constant detrend, density scaling). The
    resulting Hermitian CSD matrix at each frequency is decomposed with an SVD.

    Args:
      X: (N, M) array of N samples for M channels.
      fs: sampling frequency passed to ``csd``.
      nperseg: segment length; clamped to N when the record is shorter.
      noverlap: segment overlap; defaults to ``nperseg // 2`` and is clamped
        to at most ``min(nperseg, N) - 1``.

    Returns:
      f: (K,) frequencies
      s1: (K,) first singular value of CSD matrix at each frequency
      v1: (K, M) first singular vector (complex) at each frequency

    Raises:
      ValueError: if X is not 2-D or has no channels.
    """
    X = np.asarray(X, dtype=float)
    if X.ndim != 2:
        raise ValueError("X must be 2D (N, M)")
    N, M = X.shape
    if M < 1:
        raise ValueError("Need at least 1 channel")
    if noverlap is None:
        noverlap = nperseg // 2

    # Segment parameters are the same for every channel pair; compute them once.
    seg_len = min(nperseg, N)
    seg_overlap = min(noverlap, max(0, seg_len - 1))

    # Build the CSD matrix for each frequency.
    f = None
    S = None  # (K, M, M) complex
    for i in range(M):
        for j in range(i, M):
            fj, Pij = csd(
                X[:, i],
                X[:, j],
                fs=fs,
                nperseg=seg_len,
                noverlap=seg_overlap,
                detrend="constant",
                scaling="density",
                return_onesided=True,
            )
            if S is None:
                f = fj
                S = np.zeros((len(fj), M, M), dtype=np.complex128)
            S[:, i, j] = Pij
            if i != j:
                # The CSD matrix is Hermitian, so the lower triangle is the
                # conjugate of the upper one; no second csd call is needed.
                S[:, j, i] = np.conj(Pij)

    assert f is not None and S is not None  # guaranteed because M >= 1

    if S.shape[0] == 0:
        return f, np.zeros(0, dtype=float), np.zeros((0, M), dtype=np.complex128)

    # np.linalg.svd works on stacked matrices: one call decomposes all K
    # frequency slices. Singular values come back in descending order.
    U, s, _Vh = np.linalg.svd(S, full_matrices=False)
    s1 = np.array(s[:, 0], dtype=float)  # copy: don't keep the full s alive
    v1 = U[:, :, 0].copy()  # first left singular vector per frequency
    return f, s1, v1


def _half_power_damping(f: np.ndarray, s: np.ndarray, p: int) -> Optional[float]:
    """Estimate the damping ratio of the peak at index ``p`` by half-power bandwidth.

    ``s`` is a power-like spectrum (the first singular value of a CSD matrix),
    so the half-power level is ``s[p] / 2``. The crossing frequency on each side
    is linearly interpolated between the two bracketing samples. Returns None
    when the spectrum does not fall to half power on both sides within the
    given arrays (the bandwidth would be truncated), or when the estimate is
    degenerate.
    """
    peak = s[p]
    if not (np.isfinite(peak) and peak > 0):
        return None
    target = peak / 2.0

    # Walk outwards from the peak to the first sample at or below half power.
    li = p
    while li > 0 and s[li] > target:
        li -= 1
    ri = p
    last = len(s) - 1
    while ri < last and s[ri] > target:
        ri += 1
    # Reaching an array edge (or a NaN sample) without crossing: no estimate.
    if not (s[li] <= target and s[ri] <= target):
        return None

    # s[li] <= target < s[li + 1] and s[ri - 1] > target >= s[ri], so both
    # denominators are strictly positive.
    f_lo = f[li] + (target - s[li]) * (f[li + 1] - f[li]) / (s[li + 1] - s[li])
    f_hi = f[ri] - (target - s[ri]) * (f[ri] - f[ri - 1]) / (s[ri - 1] - s[ri])
    fp = float(f[p])
    if fp > 0 and f_hi > f_lo:
        return float((f_hi - f_lo) / (2.0 * fp))
    return None


def pick_fdd_modes(
    f: np.ndarray,
    s1: np.ndarray,
    v1: np.ndarray,
    freq_min: float,
    freq_max: float,
    max_modes: int,
    peak_prominence: float = 0.0,
) -> List[FDDEstimate]:
    """Pick modal peaks from an FDD spectrum inside ``[freq_min, freq_max]``.

    Args:
      f: (K,) frequencies, as returned by ``fdd_spectrum``.
      s1: (K,) first singular values.
      v1: (K, M) first singular vectors.
      freq_min, freq_max: inclusive frequency band to search.
      max_modes: keep at most this many peaks (the tallest ones); a value
        <= 0 returns an empty list.
      peak_prominence: minimum prominence passed to ``scipy.signal.find_peaks``.

    Returns:
      FDDEstimate list sorted by ascending frequency:
      - ``quality`` is the peak height divided by the median of the finite
        singular values in the band (1.0 is used if that median is not positive).
      - ``damping_ratio`` is a rough half-power-bandwidth estimate, or None.
    """
    if max_modes <= 0:
        return []
    mask = (f >= freq_min) & (f <= freq_max)
    if not np.any(mask):
        return []
    f2 = f[mask]
    s2 = s1[mask]
    v2 = v1[mask, :]

    # Peak picking on the first singular value.
    peaks, _props = find_peaks(s2, prominence=peak_prominence)
    if peaks.size == 0:
        return []
    # Keep the tallest peaks first.
    peaks = peaks[np.argsort(s2[peaks])[::-1]][:max_modes]

    # A NaN sample would otherwise turn the whole median into NaN.
    finite = s2[np.isfinite(s2)]
    baseline = float(np.median(finite)) if finite.size else 1.0
    if not baseline > 0:
        baseline = 1.0

    out: List[FDDEstimate] = []
    for p in peaks:
        p = int(p)
        out.append(
            FDDEstimate(
                frequency_hz=float(f2[p]),
                shape=_normalize_shape(v2[p, :]),
                quality=float(s2[p] / baseline),
                damping_ratio=_half_power_damping(f2, s2, p),
            )
        )

    # Sort by frequency (nice for downstream clustering).
    out.sort(key=lambda x: x.frequency_hz)
    return out


def cluster_by_frequency(modes: List[FDDEstimate], rel_tol: float = 0.01) -> List[FDDEstimate]:
    """Group estimates by frequency proximity and return one representative per group.

    Estimates are visited in ascending frequency order (stable sort). Each one
    joins the current group when its relative gap to the group's most recently
    added member is within ``rel_tol``; otherwise it starts a new group. This
    is single linkage, so one group may span more than ``rel_tol`` overall.
    The representative is the highest-``quality`` member; ties go to the
    lowest-frequency member.
    """
    if not modes:
        return []

    groups: List[List[FDDEstimate]] = []
    current: List[FDDEstimate] = []
    for est in sorted(modes, key=lambda e: e.frequency_hz):
        if current:
            prev_freq = current[-1].frequency_hz
            close = abs(est.frequency_hz - prev_freq) / max(1e-9, prev_freq) <= rel_tol
            if not close:
                groups.append(current)
                current = []
        current.append(est)
    groups.append(current)

    return [max(group, key=lambda e: e.quality) for group in groups]


def identify_modes_fdd(
    X: np.ndarray,
    fs: float,
    *,
    freq_min: float = 0.1,
    freq_max: float = 20.0,
    max_modes: int = 6,
    nperseg_list: Optional[List[int]] = None,
    rel_cluster_tol: float = 0.01,
    peak_prominence: float = 0.0,
) -> List[FDDEstimate]:
    """Run a lightweight OMA pipeline based on Frequency Domain Decomposition (FDD).

    This is used as the default automated analyzer in the prototype.

    The FDD spectrum is computed once per segment length in ``nperseg_list``
    (default ``[1024, 2048, 4096]``). Up to ``max_modes`` peaks are picked from
    each spectrum within ``[freq_min, freq_max]``, using ``peak_prominence`` as
    the minimum ``find_peaks`` prominence. All picks are then clustered by
    frequency with relative tolerance ``rel_cluster_tol``.

    Returns a list of clustered modal estimates (frequency + shape + quality).
    """
    if nperseg_list is None:
        nperseg_list = [1024, 2048, 4096]

    X = np.asarray(X, dtype=float)
    # fdd_spectrum clamps the segment length to the record length N, so every
    # nperseg >= N produces the same single-segment spectrum. Compute each
    # effective length only once instead of duplicating work and peaks.
    n_samples = X.shape[0] if X.ndim == 2 else None
    seen = set()

    all_modes: List[FDDEstimate] = []
    for nperseg in nperseg_list:
        seg = int(nperseg)
        key = seg if n_samples is None else min(seg, n_samples)
        if key in seen:
            continue
        seen.add(key)
        f, s1, v1 = fdd_spectrum(X, fs=fs, nperseg=seg)
        picked = pick_fdd_modes(
            f,
            s1,
            v1,
            freq_min=freq_min,
            freq_max=freq_max,
            max_modes=max_modes,
            peak_prominence=peak_prominence,
        )
        all_modes.extend(picked)

    return cluster_by_frequency(all_modes, rel_tol=rel_cluster_tol)


def shape_to_json(phi: np.ndarray) -> str:
    """Serialize a mode shape as a flat JSON list of floats (row-major order).

    Each element is converted with ``float()``, so complex entries raise.
    """
    flat = np.ravel(np.asarray(phi))
    return json.dumps(list(map(float, flat.tolist())))


def shape_from_json(s: str) -> np.ndarray:
    """Parse a JSON list (as produced by ``shape_to_json``) into a float array."""
    decoded = json.loads(s)
    return np.array(decoded, dtype=float)


def track_modes(
    detected: List[FDDEstimate],
    existing_modes: List[Dict[str, Any]],
    *,
    mac_threshold: float = 0.9,
    freq_rel_tol: float = 0.05,
) -> List[Dict[str, Any]]:
    """Match detected modes to existing tracked modes.

    existing_modes: list of dicts with keys: mode_id, ref_frequency_hz, ref_shape_json.
    Entries with missing keys or values that cannot be parsed (non-numeric id
    or frequency, invalid or empty shape JSON) are skipped.

    For each detected mode, the candidates considered are the references that
    pass both of these checks:
      - the relative frequency error is within ``freq_rel_tol``;
      - the reference shape has the same length as the detected shape (MAC is
        undefined otherwise).
    The candidate with the highest MAC is reported. It counts as a match only
    if its MAC >= ``mac_threshold``.

    Returns list of dicts per detected mode (same order as ``detected``):
      - matched_mode_id (or None)
      - best_mac (0.0 when no candidate was considered)
    """
    refs = []
    for entry in existing_modes:
        try:
            refs.append(
                (
                    int(entry["mode_id"]),
                    float(entry["ref_frequency_hz"]),
                    _normalize_shape(shape_from_json(entry["ref_shape_json"])),
                )
            )
        except (KeyError, TypeError, ValueError, OverflowError):
            # Malformed stored reference: skip it rather than abort tracking.
            continue

    out: List[Dict[str, Any]] = []
    for det in detected:
        det_shape = np.asarray(det.shape).reshape(-1)
        best_id = None
        best_mac = 0.0
        for mode_id, fref, pref in refs:
            rel_err = abs(det.frequency_hz - fref) / max(1e-9, fref)
            # Written as "not <=" so a NaN error (e.g. NaN ref frequency) is rejected.
            if not rel_err <= freq_rel_tol:
                continue
            if pref.size != det_shape.size:
                # Different sensor layout: np.vdot inside mac() would raise.
                continue
            mval = mac(det_shape, pref)
            if mval > best_mac:
                best_id, best_mac = mode_id, mval
        matched = best_id if best_mac >= mac_threshold else None
        out.append({"matched_mode_id": matched, "best_mac": best_mac})
    return out
