from __future__ import annotations

import math
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
from influxdb_client import Point, WritePrecision

from shm_shared.settings import settings
from shm_shared.logging import setup_logging
from shm_shared.metadata_client import get_config, create_cleaning_run, complete_cleaning_run, fail_cleaning_run, get_upload, list_sensors
from shm_shared.cleaning import clean_signal_matrix
from shm_shared.influx import get_influx_client, write_points

logger = setup_logging("cleaning.pipeline")


def _iso_utc(dt: datetime) -> str:
    """Convert datetime to ISO 8601 UTC format for Flux queries."""
    return dt.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")


def _flux_string(value: object) -> str:
    """Escape *value* for embedding inside a double-quoted Flux string literal.

    Backslashes and double quotes are escaped. ``${`` is escaped too, so Flux
    does not treat it as the start of string interpolation.
    """
    return str(value).replace("\\", "\\\\").replace('"', '\\"').replace("${", "\\${")


def _datetime_to_ns(dt: datetime) -> int:
    """Return *dt* as integer nanoseconds since the Unix epoch.

    Uses exact integer ``timedelta`` arithmetic instead of
    ``dt.timestamp() * 1e9``. Present-day epoch nanoseconds (~1.7e18) need
    about 61 bits, but float64 has a 53-bit mantissa, so the float route rounds
    and can shift microsecond-resolution times by a few hundred nanoseconds.
    As with ``timestamp()``, ``astimezone`` treats a naive *dt* as local time.
    """
    delta = dt.astimezone(timezone.utc) - datetime(1970, 1, 1, tzinfo=timezone.utc)
    return (delta.days * 86_400 + delta.seconds) * 1_000_000_000 + delta.microseconds * 1_000


def _query_sensor(bucket: str, structure_name: str, sensor_name: str, start: datetime, stop: datetime) -> Tuple[np.ndarray, np.ndarray]:
    """Query raw accelerometer data from InfluxDB for a specific sensor.

    Selects measurement ``accel`` rows whose ``Structure``/``Sensor`` tags
    match, within ``[start, stop)`` (Flux ``range`` excludes ``stop``).
    Tag values and the bucket name are escaped before being embedded in the
    Flux query.

    Args:
        bucket: InfluxDB bucket name (e.g., settings.RAW_BUCKET)
        structure_name: Structure identifier
        sensor_name: Sensor identifier
        start: Query start time
        stop: Query stop time

    Returns:
        Tuple of (times_ns, values), sorted by time, where:
        - times_ns: numpy array of int64 nanosecond timestamps
        - values: numpy array of float sensor values
        Both arrays are empty when no rows match.
    """
    client = get_influx_client()
    try:
        q = client.query_api()
        # NOTE(review): there is no `_field` filter. If the `accel` measurement
        # stores more than one field, their values would be interleaved here.
        # Confirm there is a single field.
        flux = f'''
from(bucket: "{_flux_string(bucket)}")
  |> range(start: time(v: "{_iso_utc(start)}"), stop: time(v: "{_iso_utc(stop)}"))
  |> filter(fn: (r) => r._measurement == "accel")
  |> filter(fn: (r) => r.Structure == "{_flux_string(structure_name)}")
  |> filter(fn: (r) => r.Sensor == "{_flux_string(sensor_name)}")
  |> keep(columns: ["_time","_value"])
  |> sort(columns: ["_time"])
'''
        tables = q.query(flux, org=settings.INFLUX_ORG)
        times_list: List[int] = []
        values_list: List[float] = []

        for table in tables:
            for record in table.records:
                times_list.append(_datetime_to_ns(record["_time"]))
                values_list.append(float(record["_value"]))

        if not times_list:
            logger.warning(f"No data found for {structure_name}/{sensor_name}")
            return np.array([], dtype=np.int64), np.array([], dtype=float)

        times = np.asarray(times_list, dtype=np.int64)
        values = np.asarray(values_list, dtype=float)
        # Flux sort() orders rows within each table only. If the result spans
        # several tables, their concatenation can be out of order, so enforce a
        # global time order (a no-op when there is a single table).
        order = np.argsort(times, kind="stable")
        return times[order], values[order]
    finally:
        client.close()


def _write_cleaned_series(structure_name: str, sensor_name: str, times: np.ndarray, values: np.ndarray) -> int:
    """Write cleaned accelerometer data to the Cleaned InfluxDB bucket.

    Each sample becomes an ``accel`` point tagged with ``Structure``/``Sensor``
    and carrying a float ``value`` field at nanosecond precision.

    If ``times`` and ``values`` differ in length, only their common prefix is
    written and a warning is logged. Non-finite values (NaN/±inf) cannot be
    stored through line protocol, so they are skipped with a warning rather
    than being sent as empty records.

    Args:
        structure_name: Structure identifier
        sensor_name: Sensor identifier
        times: numpy array of int64 nanosecond timestamps
        values: numpy array of float sensor values

    Returns:
        Number of points written, as reported by ``write_points``; 0 when
        nothing writable remains.
    """
    if len(times) == 0 or len(values) == 0:
        logger.warning(f"No data to write for {structure_name}/{sensor_name}")
        return 0

    n = min(len(times), len(values))
    if len(times) != len(values):
        logger.warning(
            f"Length mismatch for {structure_name}/{sensor_name}: "
            f"{len(times)} timestamps vs {len(values)} values; writing first {n}"
        )

    times_arr = np.asarray(times[:n], dtype=np.int64)
    values_arr = np.asarray(values[:n], dtype=float)
    finite = np.isfinite(values_arr)
    dropped = int(n - np.count_nonzero(finite))
    if dropped:
        logger.warning(f"Skipping {dropped} non-finite value(s) for {structure_name}/{sensor_name}")
    if dropped == n:
        return 0

    points = [
        Point("accel")
        .tag("Structure", structure_name)
        .tag("Sensor", sensor_name)
        .field("value", val)
        .time(ts_ns, WritePrecision.NS)
        # .tolist() yields plain Python int/float for the client library.
        for ts_ns, val in zip(times_arr[finite].tolist(), values_arr[finite].tolist())
    ]

    return write_points(points, bucket=settings.CLEANED_BUCKET)


@dataclass
class CleaningParams:
    """Runtime parameters for one cleaning run.

    Produced by ``get_cleaning_params``, which merges the metadata_api
    ``cleaning`` config with ``settings.CLEANING_*`` fallbacks.
    """

    # High-pass cutoff forwarded to clean_signal_matrix as ``hp``; may be None.
    hp: Optional[float]
    # Low-pass cutoff forwarded to clean_signal_matrix as ``lp``; may be None.
    lp: Optional[float]
    # Decimation factor forwarded to clean_signal_matrix as ``ss``.
    downsample: int
    # Filter order forwarded to clean_signal_matrix.
    filter_order: int
    # Detrend order forwarded to clean_signal_matrix.
    detrend_order: int


def _as_float(v) -> Optional[float]:
    if v in (None, "", "null", "None"):
        return None
    try:
        return float(v)
    except Exception:
        return None


def _as_int(v, default: int) -> int:
    try:
        return int(float(v))
    except Exception:
        return int(default)


def get_cleaning_params() -> CleaningParams:
    """Load runtime params from metadata_api config (service='cleaning').

    Each key (``hp``, ``lp``, ``downsample``, ``filter_order``,
    ``detrend_order``) falls back to its ``settings.CLEANING_*`` default when it
    is missing or unparseable.

    Loading is best-effort. If ``get_config`` raises, or returns something that
    is not a mapping, every parameter uses its default and a warning is logged,
    so the failure is not silent.

    Returns:
        CleaningParams, with ``downsample`` normalized to at least 1.
    """
    cfg: Mapping = {}
    try:
        loaded = get_config("cleaning")
    except Exception as exc:
        logger.warning(f"Could not load 'cleaning' config from metadata_api; using defaults: {exc}")
        loaded = None
    if loaded:
        if isinstance(loaded, Mapping):
            cfg = loaded
        else:
            logger.warning(f"Ignoring non-mapping 'cleaning' config of type {type(loaded).__name__}; using defaults")

    hp = _as_float(cfg.get("hp"))
    lp = _as_float(cfg.get("lp"))

    downsample = _as_int(cfg.get("downsample"), settings.CLEANING_DOWNSAMPLE)
    if downsample < 1:
        # A decimation factor below 1 is meaningless. clean_upload_influx
        # already slices timestamps with max(1, downsample), so normalize here
        # to keep the filter step and the timestamps consistent.
        logger.warning(f"Invalid downsample factor {downsample}; using 1")
        downsample = 1

    return CleaningParams(
        hp=hp if hp is not None else settings.CLEANING_HP,
        lp=lp if lp is not None else settings.CLEANING_LP,
        downsample=downsample,
        filter_order=_as_int(cfg.get("filter_order"), settings.CLEANING_FILTER_ORDER),
        detrend_order=_as_int(cfg.get("detrend_order"), settings.CLEANING_DETREND_ORDER),
    )



def _parse_upload_time(value) -> datetime:
    """Parse an upload timestamp (ISO 8601 string or datetime) into an aware datetime.

    A trailing ``Z`` is rewritten to ``+00:00`` because ``datetime.fromisoformat``
    only accepts ``Z`` from Python 3.11. Naive values are assumed to be UTC;
    otherwise ``astimezone`` would shift them by the server's local offset.
    NOTE(review): confirm metadata_api never returns local wall-clock times.
    """
    dt = value if isinstance(value, datetime) else datetime.fromisoformat(str(value).replace("Z", "+00:00"))
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt


def clean_upload_influx(upload_id: int, structure_name_override: Optional[str] = None, sensor_names: Optional[List[str]] = None, original_fs: Optional[float] = None) -> Dict[str, object]:
    """Clean an upload window from Raw bucket to Cleaned bucket and log to metadata_api.

    Steps:
      1. Load cleaning params and the upload record, then derive the time window.
      2. Register a "running" cleaning run in metadata_api.
      3. Query each sensor from the Raw bucket, trim all series to a common
         length, and run ``clean_signal_matrix`` on the stacked matrix.
      4. Write each cleaned column to the Cleaned bucket and mark the run
         complete.
    If any step after (2) fails, the run is marked failed and the exception is
    re-raised.

    Args:
        upload_id: metadata_api upload identifier.
        structure_name_override: Structure tag to query/write; defaults to the
            upload's ``structure_name``, then ``settings.MONITORED_STRUCTURE``.
        sensor_names: Sensors to process. When None, the list comes from
            ``list_sensors(structure_id)``.
        original_fs: Raw sampling rate; a falsy value falls back to 100.0.

    Returns:
        Summary dict with status, ids, point counts, cleaned_fs, structure
        name and the ISO window bounds.

    Raises:
        RuntimeError: if the upload lacks a time window, the window is empty or
            inverted, no sensors are found, or fewer than 2 common samples exist.
    """
    params = get_cleaning_params()
    up = get_upload(upload_id)

    structure_id = int(up["structure_id"])
    structure_name = structure_name_override or up.get("structure_name") or settings.MONITORED_STRUCTURE

    start_s = up.get("file_start_time") or up.get("start_time")
    stop_s = up.get("file_end_time") or up.get("end_time")
    if not start_s or not stop_s:
        raise RuntimeError("Upload does not include file_start_time/file_end_time")

    start = _parse_upload_time(start_s)
    stop = _parse_upload_time(stop_s)
    if stop <= start:
        raise RuntimeError(f"Upload window is empty or inverted: start={start.isoformat()} stop={stop.isoformat()}")

    original_fs = float(original_fs or 100.0)
    # Single effective decimation factor, used for the filter step, the
    # timestamp slicing and the recorded cleaned_fs so they cannot disagree.
    ss = max(1, int(params.downsample))

    cleaning_id = create_cleaning_run(
        {
            "upload_id": upload_id,
            "structure_id": structure_id,
            "status": "running",
            "hp_cutoff": params.hp,
            "lp_cutoff": params.lp,
            "downsample_factor": params.downsample,
            "filter_order": params.filter_order,
            "detrend_order": params.detrend_order,
            "original_fs": original_fs,
            "cleaned_fs": original_fs / ss,
            "notes": "auto",
        }
    )

    try:
        # Resolve sensor list (dynamic)
        if sensor_names is None:
            sensor_names = [r.get("name") for r in list_sensors(structure_id) if r.get("name")]
        if not sensor_names:
            raise RuntimeError("No sensors found for this structure (metadata_api).")

        # Query per-sensor data. Alignment is positional: every sensor is
        # assumed to share the first sensor's timestamps, and all series are
        # trimmed to the shortest common length.
        times_ref: Optional[np.ndarray] = None
        series: List[np.ndarray] = []
        lengths: List[int] = []
        points_read = 0
        for sensor in sensor_names:
            t_ns, vals = _query_sensor(settings.RAW_BUCKET, structure_name, sensor, start, stop)
            if times_ref is None:
                times_ref = t_ns
            lengths.append(int(min(len(vals), len(t_ns), len(times_ref))))
            series.append(vals)
            points_read += len(vals)

        min_len = min(lengths)
        if times_ref is None or min_len < 2:
            raise RuntimeError("No raw data found for upload window")

        X = np.stack([v[:min_len] for v in series], axis=1)
        times_ref = times_ref[:min_len]

        cleaned, fs_new = clean_signal_matrix(
            X,
            fc=float(original_fs),
            ss=ss,
            hp=params.hp,
            lp=params.lp,
            filter_order=int(params.filter_order),
            detrend_order=int(params.detrend_order),
        )

        times_new = times_ref[::ss][: cleaned.shape[0]]

        points_written = 0
        for idx, sensor in enumerate(sensor_names):
            points_written += _write_cleaned_series(structure_name, sensor, times_new, cleaned[:, idx])

        complete_cleaning_run(cleaning_id, points_read=points_read, points_written=points_written)

        return {
            "status": "success",
            "upload_id": upload_id,
            "cleaning_id": cleaning_id,
            "points_read": points_read,
            "points_written": points_written,
            "cleaned_fs": float(fs_new),
            "structure_name": structure_name,
            "start_time": start.isoformat(),
            "end_time": stop.isoformat(),
        }

    except Exception as e:
        logger.exception(f"clean_upload_influx failed for upload_id={upload_id}: {e}")
        try:
            fail_cleaning_run(cleaning_id, str(e))
        except Exception:
            # Never mask the original error, but leave a trace that the run
            # status could not be updated (it may remain "running").
            logger.exception(f"Could not mark cleaning run {cleaning_id} as failed")
        raise
