from __future__ import annotations

import math
import time
from datetime import datetime, timezone

import numpy as np
from influxdb_client import Point, WritePrecision

from shm_shared.events import EventBus, parse_payload
from shm_shared.influx import get_influx_client, write_points
from shm_shared.logging import setup_logging
from shm_shared.metadata_client import (
    create_analysis_run,
    complete_analysis_run,
    fail_analysis_run,
    get_config,
    list_oma_modes,
    create_oma_mode,
    mark_oma_mode_seen,
    create_oma_observation,
)
from shm_shared.oma import identify_modes_fdd, track_modes, shape_to_json, shape_from_json, _normalize_shape
from shm_shared.settings import settings

logger = setup_logging("analysis.worker")

# Consume cleaned_ready events as a member of the analysis consumer group.
bus_in = EventBus(stream=settings.STREAM_CLEANED_READY, group=settings.GROUP_ANALYSIS)
# Publish analysis_done events; this bus is only used for publish() in this module, hence no group.
bus_out = EventBus(stream=settings.STREAM_ANALYSIS_DONE, group=None)

# Consumer name passed to consume_one()/claim_stale() within GROUP_ANALYSIS.
# NOTE(review): a fixed name means every replica of this worker shares one consumer identity — confirm only one instance runs.
CONSUMER = "analysis_worker"


def _iso_utc(dt: datetime) -> str:
    """Render *dt* as an RFC 3339 UTC timestamp with whole-second precision.

    Aware datetimes are converted to UTC; naive ones are treated by
    ``astimezone`` as host-local time. Sub-second digits are dropped.
    """
    utc_dt = dt.astimezone(timezone.utc)
    return f"{utc_dt:%Y-%m-%dT%H:%M:%SZ}"


_UNIX_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)


def _flux_string(value: str) -> str:
    """Escape *value* for safe use inside a double-quoted Flux string literal."""
    # Backslash first so the escapes added below are not themselves re-escaped;
    # "${" would otherwise start Flux string interpolation.
    return value.replace("\\", "\\\\").replace('"', '\\"').replace("${", "\\${")


def _datetime_to_ns(dt: datetime) -> int:
    """Convert *dt* to integer nanoseconds since the Unix epoch, without float rounding.

    A naive datetime is taken as UTC.
    """
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    delta = dt - _UNIX_EPOCH
    return (delta.days * 86_400 + delta.seconds) * 1_000_000_000 + delta.microseconds * 1_000


def _query_cleaned_series(structure: str, start: datetime, stop: datetime) -> dict[str, tuple[np.ndarray, np.ndarray]]:
    """Query cleaned ``accel`` samples for every sensor of *structure*.

    Reads ``settings.CLEANED_BUCKET`` over the Flux range ``[start, stop)``
    (both bounds rendered at whole-second precision by ``_iso_utc``) and groups
    rows by their ``Sensor`` column; rows without a sensor name are skipped.
    *structure* is escaped before being embedded in the query, since it comes
    from the incoming message payload.

    Returns:
        dict: sensor_name -> (times_ns, values), where ``times_ns`` is an int64
        array of epoch nanoseconds in ascending order and ``values`` a float array.
    """
    flux = f'''
from(bucket: "{settings.CLEANED_BUCKET}")
  |> range(start: time(v: "{_iso_utc(start)}"), stop: time(v: "{_iso_utc(stop)}"))
  |> filter(fn: (r) => r._measurement == "accel")
  |> filter(fn: (r) => r.Structure == "{_flux_string(str(structure))}")
  |> keep(columns: ["_time","_value","Sensor"])
  |> sort(columns: ["_time"])
'''
    client = get_influx_client()
    try:
        tables = client.query_api().query(flux, org=settings.INFLUX_ORG)
        grouped: dict[str, list[tuple[int, float]]] = {}
        for table in tables:
            for rec in table.records:
                # FluxRecord exposes row data via .values / [] (it has no .get()).
                sensor = str(rec.values.get("Sensor") or "")
                if not sensor:
                    continue
                grouped.setdefault(sensor, []).append((_datetime_to_ns(rec["_time"]), float(rec["_value"])))
    finally:
        client.close()

    series: dict[str, tuple[np.ndarray, np.ndarray]] = {}
    for sensor, items in grouped.items():
        times = np.asarray([ts for ts, _ in items], dtype=np.int64)
        vals = np.asarray([v for _, v in items], dtype=float)
        # Flux sorts within each table; a sensor split across several tables
        # would otherwise be concatenated out of order.
        order = np.argsort(times, kind="stable")
        series[sensor] = (times[order], vals[order])
    return series


def _align_matrix(series: dict[str, tuple[np.ndarray, np.ndarray]], sensor_order: list[str] | None = None) -> tuple[np.ndarray, np.ndarray, list[str]]:
    if not series:
        raise RuntimeError("No cleaned data returned from Influx.")
    sensors = sensor_order[:] if sensor_order else sorted(series.keys())
    sensors = [s for s in sensors if s in series]
    if not sensors:
        raise RuntimeError("No matching sensors in queried cleaned data.")
    min_len = min(len(series[s][1]) for s in sensors)
    if min_len < 10:
        raise RuntimeError("Cleaned data window is too short.")
    # Reference time is first sensor
    t_ref = series[sensors[0]][0][:min_len]
    X = np.stack([series[s][1][:min_len] for s in sensors], axis=1)
    return t_ref, X, sensors


def _write_mode_point(structure: str, mode_id: int, freq_hz: float, mac_val: float | None, quality: float | None, t_ns: int) -> int:
    """Write one ``oma_modes`` point to ``settings.OMA_BUCKET``.

    Tags: ``Structure`` and ``ModeID``. Fields: ``frequency_hz`` always, plus
    ``mac`` and ``quality`` only when they are given and finite. Line protocol
    has no NaN/Inf representation, so absent values are omitted rather than
    written as a NaN sentinel.

    Returns:
        Whatever ``write_points`` returns for the single-point batch.
    """
    pt = (
        Point("oma_modes")
        .tag("Structure", structure)
        .tag("ModeID", str(mode_id))
        .field("frequency_hz", float(freq_hz))
        .time(int(t_ns), WritePrecision.NS)
    )
    for field_name, value in (("mac", mac_val), ("quality", quality)):
        if value is not None and math.isfinite(float(value)):
            pt.field(field_name, float(value))
    return write_points([pt], bucket=settings.OMA_BUCKET, chunk_size=1000)


def _analysis_params() -> dict:
    """Return OMA tuning parameters read from the ``"analysis"`` config.

    Each value falls back to its built-in default in any of these cases:
    the config lookup fails, the config is not a mapping, the key is absent,
    or the value cannot be parsed as a finite number. A failed lookup is
    logged rather than silently ignored, because it means the run uses
    defaults only.

    Returns:
        dict with float ``freq_min``, ``freq_max``, ``mac_threshold``,
        ``freq_rel_tol``, ``cluster_tol`` and int ``max_modes``.
    """
    try:
        cfg = get_config("analysis") or {}
    except Exception as e:
        # Best effort: proceed on defaults, but leave a trace.
        logger.warning(f"Could not load 'analysis' config; using defaults: {e}")
        cfg = {}
    if not hasattr(cfg, "get"):
        logger.warning(f"Ignoring non-mapping 'analysis' config: {cfg!r}")
        cfg = {}

    def _f(key, default):
        try:
            value = float(cfg.get(key, default))
        except (TypeError, ValueError):
            return float(default)
        # float() accepts "nan"/"inf"; treat them as invalid like any other unparsable value.
        return value if math.isfinite(value) else float(default)

    def _i(key, default):
        try:
            return int(float(cfg.get(key, default)))
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers int(float("inf")).
            return int(default)

    return {
        "freq_min": _f("freq_min", 0.1),
        "freq_max": _f("freq_max", 20.0),
        "max_modes": _i("max_modes", 6),
        "mac_threshold": _f("mac_threshold", 0.90),
        "freq_rel_tol": _f("freq_rel_tol", 0.05),
        "cluster_tol": _f("cluster_tol", 0.01),
    }


def main():
    """Worker entry point: re-handle claimed stale messages, then consume forever.

    Exceptions raised by ``_handle`` are logged and never stop the worker.
    In the live loop, a failure is followed by a 1 s pause before the next read.
    """
    logger.info("OMA analysis worker starting (Redis Streams)...")
    bus_in.ensure_group()

    # Presumably claims messages pending for >= 60 s (min_idle_ms=60000); each
    # is handled once before the live loop starts.
    for stale_id, stale_fields in bus_in.claim_stale(CONSUMER, min_idle_ms=60000):
        try:
            _handle(stale_id, stale_fields)
        except Exception as err:
            logger.exception(f"Failed to process stale message {stale_id}: {err}")

    while True:
        # block_ms=5000: a falsy result means nothing arrived in the wait window.
        msg = bus_in.consume_one(CONSUMER, block_ms=5000)
        if msg:
            msg_id, msg_fields = msg
            try:
                _handle(msg_id, msg_fields)
            except Exception as err:
                logger.exception(f"Failed to process message {msg_id}: {err}")
                # Brief back-off so a persistently failing dependency does not spin.
                time.sleep(1)


def _parse_utc(value) -> datetime:
    """Parse an ISO-8601 timestamp into an aware datetime.

    A trailing ``Z`` is accepted (``fromisoformat`` only supports it from
    Python 3.11). A value without an offset is taken as UTC, not host-local time.
    """
    dt = datetime.fromisoformat(str(value).replace("Z", "+00:00"))
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt


def _persist_mode(run_id, structure_id: int, structure_name: str, det, match, existing_by_id: dict, end_ns: int, alpha: float = 0.2) -> int:
    """Persist one detected mode and return its mode_id.

    Unmatched detections create a new mode, using the detection as reference and
    reporting MAC 1.0. Matched detections mark the mode as seen. When the mode's
    stored record is available, its reference frequency and shape are also
    nudged towards the detection with an EMA of weight *alpha*, so they follow
    slow drift. In every case an observation row and an ``oma_modes`` Influx
    point are written.
    """
    mode_id = match["matched_mode_id"]
    raw_mac = match["best_mac"]
    best_mac = float(raw_mac) if raw_mac is not None else None

    if mode_id is None:
        created = create_oma_mode(
            {
                "structure_id": structure_id,
                "name": None,
                "ref_frequency_hz": det.frequency_hz,
                "ref_shape_json": shape_to_json(det.shape),
            }
        )
        mode_id = int(created["mode_id"])
        best_mac = 1.0
    else:
        mode_id = int(mode_id)
        ex = existing_by_id.get(mode_id)
        if ex:
            f_ref = float(ex.get("ref_frequency_hz") or det.frequency_hz)
            phi_ref = np.asarray(_normalize_shape(shape_from_json(ex.get("ref_shape_json") or shape_to_json(det.shape))))
            phi_det = np.asarray(_normalize_shape(np.asarray(det.shape)))
            # Mode shapes are defined only up to sign (or complex phase). MAC
            # ignores that, but the EMA would partially cancel a flipped shape,
            # so rotate the detection onto the reference before blending.
            overlap = np.vdot(phi_ref, phi_det)
            if abs(overlap) > 0:
                phi_det = phi_det * (np.conj(overlap) / abs(overlap))
            f_new = (1 - alpha) * f_ref + alpha * float(det.frequency_hz)
            phi_new = _normalize_shape((1 - alpha) * phi_ref + alpha * phi_det)
            mark_oma_mode_seen(mode_id, {"ref_frequency_hz": float(f_new), "ref_shape_json": shape_to_json(phi_new)})
        else:
            mark_oma_mode_seen(mode_id, {})

    quality = float(det.quality) if det.quality is not None else None
    create_oma_observation(
        {
            "run_id": run_id,
            "mode_id": mode_id,
            "frequency_hz": float(det.frequency_hz),
            "damping_ratio": float(det.damping_ratio) if det.damping_ratio is not None else None,
            "mac": best_mac,
            "quality": quality,
        }
    )
    _write_mode_point(structure_name, mode_id, float(det.frequency_hz), best_mac, quality, end_ns)
    return mode_id


def _handle(message_id: str, fields: dict[str, str]):
    """Run FDD-based OMA for one ``cleaned_ready`` message and persist the result.

    Some messages are acked and skipped because retrying can never help:
    messages of another type, messages without ``start_time``/``end_time``,
    and messages whose ``upload_id``, ``structure_id``, timestamps or
    ``cleaned_fs`` cannot be parsed.

    For valid messages, an analysis run is created and the cleaned window is
    queried and aligned. Modes are then identified with ``identify_modes_fdd``
    and tracked against the structure's known modes. Each mode is persisted,
    the run is completed, and an ``analysis_done`` event is published. Only
    then is the message acked.

    If anything fails after the run is created, the run is marked failed and
    the message is left unacknowledged, so it stays pending; ``main`` re-handles
    claimed stale messages at startup. Errors raised before the run exists
    (``parse_payload``, ``create_analysis_run``) propagate to the caller.

    NOTE(review): a failure part-way through the mode loop leaves the modes and
    observations already written in place; a retry writes them again.
    """
    payload = parse_payload(fields)
    if payload.get("type") != "cleaned_ready":
        bus_in.ack(message_id)
        return

    start_s = payload.get("start_time")
    end_s = payload.get("end_time")
    if not start_s or not end_s:
        bus_in.ack(message_id)
        return

    try:
        upload_id = int(payload["upload_id"])
        structure_id = int(payload.get("structure_id") or 1)
        start = _parse_utc(start_s)
        end = _parse_utc(end_s)
        cleaned_fs = float(payload.get("cleaned_fs") or 50.0)
    except (KeyError, TypeError, ValueError) as e:
        # Malformed payloads can never succeed; ack them so they do not stay pending forever.
        logger.error(f"Dropping malformed cleaned_ready message {message_id}: {e}")
        bus_in.ack(message_id)
        return

    structure_name = payload.get("structure_name") or settings.MONITORED_STRUCTURE
    sensor_names = payload.get("sensor_names")
    sensor_order = sensor_names if isinstance(sensor_names, list) else None

    run_id = create_analysis_run(
        {
            "structure_id": structure_id,
            "upload_id": upload_id,
            "status": "running",
            "notes": "auto_oma_fdd",
        }
    )

    try:
        params = _analysis_params()
        series = _query_cleaned_series(structure_name, start, end)
        _, X, sensors = _align_matrix(series, sensor_order=sensor_order)

        detected = identify_modes_fdd(
            X,
            fs=cleaned_fs,
            freq_min=params["freq_min"],
            freq_max=params["freq_max"],
            max_modes=params["max_modes"],
            rel_cluster_tol=params["cluster_tol"],
        )

        existing = list_oma_modes(structure_id=structure_id)
        matches = track_modes(
            detected,
            existing,
            mac_threshold=params["mac_threshold"],
            freq_rel_tol=params["freq_rel_tol"],
        )

        # Index known modes once; the first record wins for a repeated id, and
        # records without an id are ignored instead of crashing the lookup.
        existing_by_id: dict[int, dict] = {}
        for m in existing:
            if m.get("mode_id") is not None:
                existing_by_id.setdefault(int(m["mode_id"]), m)

        # Persist: modes + observations + Influx points, all stamped at the window end.
        end_ns = int(end.timestamp() * 1e9)
        persisted: list[int] = []
        for det, match in zip(detected, matches):
            persisted.append(_persist_mode(run_id, structure_id, structure_name, det, match, existing_by_id, end_ns))

        complete_analysis_run(run_id, notes=f"modes={len(persisted)} sensors={len(sensors)} fs={cleaned_fs:.3f}")
        bus_out.publish(
            {
                "type": "analysis_done",
                "upload_id": upload_id,
                "structure_id": structure_id,
                "structure_name": structure_name,
                "end_time": end.isoformat(),
                "modes_detected": len(persisted),
            }
        )

        bus_in.ack(message_id)
        logger.info(f"OMA analysis complete upload_id={upload_id} run_id={run_id} modes={len(persisted)}")

    except Exception as e:
        logger.exception(f"OMA analysis failed upload_id={upload_id}: {e}")
        try:
            fail_analysis_run(run_id, str(e))
        except Exception as fail_err:
            # Do not mask the original failure, but do not hide this one either.
            logger.warning(f"Could not mark analysis run {run_id} as failed: {fail_err}")
        # Leave unacked: the message stays pending for a later retry.


# Start the worker loop only when run as a script, not when imported.
if __name__ == "__main__":
    main()
