from __future__ import annotations

import json
import os
import secrets
import tempfile
from dataclasses import dataclass
from datetime import timezone
from typing import Any, Dict, Optional
from urllib.parse import quote

import httpx
import numpy as np
import plotly.graph_objects as go
from fastapi import Depends, FastAPI, File, Form, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from scipy.signal import welch

from shm_shared.cleaning import clean_signal_matrix
from shm_shared.logging import setup_logging
from shm_shared.oma import identify_modes_fdd
from shm_shared.parsing import parse_txt_file
from shm_shared.settings import settings

# Logger handle returned by the shared logging setup, registered under "management.ui".
logger = setup_logging("management.ui")

# FastAPI application object. The template and static directories are given as
# relative paths, so they are resolved against the process working directory.
app = FastAPI(title="SHM Management UI")
templates = Jinja2Templates(directory="services/management_ui/templates")
app.mount("/static", StaticFiles(directory="services/management_ui/static"), name="static")

# HTTP Basic scheme; injected into verify_auth() through Depends().
security = HTTPBasic()


def verify_auth(credentials: HTTPBasicCredentials = Depends(security)):
    """Verify HTTP Basic Authentication credentials.

    Args:
        credentials: Username/password pair extracted by the HTTP Basic scheme.

    Returns:
        The authenticated username.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` challenge when the
            username or the password does not match the configured values.
    """
    # compare_digest() raises TypeError for str arguments that contain
    # non-ASCII characters, which would turn a bad login into a 500. Comparing
    # the UTF-8 encoded bytes accepts any input and stays constant-time.
    correct_username = secrets.compare_digest(
        credentials.username.encode("utf-8"),
        settings.UI_USERNAME.encode("utf-8"),
    )
    # Both comparisons are always evaluated (no short-circuit) so the response
    # time does not reveal which of the two fields was wrong.
    correct_password = secrets.compare_digest(
        credentials.password.encode("utf-8"),
        settings.UI_PASSWORD.encode("utf-8"),
    )
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
            headers={"WWW-Authenticate": "Basic"},
        )
    return credentials.username


def _meta_base() -> str:
    """Return the metadata API base URL with any trailing slashes removed."""
    base_url = settings.METADATA_API_URL
    if not base_url:
        # Nothing configured: fall back to the built-in default address.
        base_url = "http://metadata_api:8001"
    return base_url.rstrip("/")


async def _get_json(url: str, default: Any):
    """GET *url* and return the decoded JSON body, or *default* on any failure.

    A response status of 400 or above, as well as any exception raised while
    requesting or decoding, yields *default*.
    """
    try:
        async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as http:
            response = await http.get(url)
            succeeded = response.status_code < 400
            return response.json() if succeeded else default
    except Exception:
        # Best-effort read: callers only ever see the fallback value.
        return default


async def _post_json(url: str, payload: dict | None, default: Any):
    """POST *payload* as JSON to *url* and return ``(result, error)``.

    On success ``error`` is ``None`` and ``result`` is the decoded JSON body
    (or *default* when the body is empty). On an HTTP status of 400 or above
    ``result`` is *default* and ``error`` is the response text; on an exception
    ``error`` is the exception message.
    """
    try:
        async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as http:
            response = await http.post(url, json=payload)
            if response.status_code < 400:
                body = response.json() if response.content else default
                return body, None
            return default, response.text
    except Exception as exc:
        return default, str(exc)


async def _put_json(url: str, payload: dict, default: Any):
    """PUT *payload* as JSON to *url* and return ``(result, error)``.

    On success ``error`` is ``None`` and ``result`` is the decoded JSON body
    (or *default* when the body is empty). On an HTTP status of 400 or above
    ``result`` is *default* and ``error`` is the response text; on an exception
    ``error`` is the exception message.
    """
    try:
        async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as http:
            response = await http.put(url, json=payload)
            if response.status_code < 400:
                body = response.json() if response.content else default
                return body, None
            return default, response.text
    except Exception as exc:
        return default, str(exc)


async def _delete(url: str) -> tuple[bool, Optional[str]]:
    """Send DELETE to *url* and return ``(ok, error)``.

    ``ok`` is True only for status 200, 202 or 204. Otherwise ``error`` carries
    the response text, or the exception message if the request itself failed.
    """
    try:
        async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as http:
            response = await http.delete(url)
            accepted = response.status_code in (200, 202, 204)
            return (True, None) if accepted else (False, response.text)
    except Exception as exc:
        return False, str(exc)


@app.get("/healthz")
def health():
    """Liveness probe: always reports an ``ok`` status."""
    payload = {"status": "ok"}
    return payload


# -----------------------------
# Home
# -----------------------------
@app.get("/", response_class=HTMLResponse)
async def home(request: Request, msg: Optional[str] = None, edit: Optional[int] = None, show_sensors: Optional[int] = None, username: str = Depends(verify_auth)):
    """Render the dashboard page.

    Loads the cleaning, system and analysis configs and the structure list from
    the metadata service, fills missing config values from ``settings`` and,
    when ``show_sensors == 1``, also loads the sensors of the default structure.

    Args:
        request: Incoming request, forwarded to the template.
        msg: Optional status message shown on the page.
        edit: ``1`` switches the page into edit mode.
        show_sensors: ``1`` enables the sensors panel.
        username: Authenticated user (enforces authentication only).

    Returns:
        The rendered ``index.html`` template response.
    """
    meta = _meta_base()

    edit_mode = edit == 1
    show_sensors_flag = show_sensors == 1

    # Configs from metadata service (stored in Postgres pipeline_config)
    cleaning_cfg_raw = await _get_json(f"{meta}/config/cleaning", {})
    system_cfg = await _get_json(f"{meta}/config/system", {})
    analysis_cfg_raw = await _get_json(f"{meta}/config/analysis", {})

    # _get_json returns whatever JSON the service sent; anything that is not an
    # object (e.g. null or a list) is treated as "no stored config" so the
    # .get() calls below cannot fail.
    if not isinstance(cleaning_cfg_raw, dict):
        cleaning_cfg_raw = {}
    if not isinstance(system_cfg, dict):
        system_cfg = {}
    if not isinstance(analysis_cfg_raw, dict):
        analysis_cfg_raw = {}

    # Fill missing values from env defaults (so UI always shows current effective config)
    cleaning_cfg = {
        "hp": cleaning_cfg_raw.get("hp", settings.CLEANING_HP),
        "lp": cleaning_cfg_raw.get("lp", settings.CLEANING_LP),
        "downsample": cleaning_cfg_raw.get("downsample", settings.CLEANING_DOWNSAMPLE),
        "filter_order": cleaning_cfg_raw.get("filter_order", settings.CLEANING_FILTER_ORDER),
        "detrend_order": cleaning_cfg_raw.get("detrend_order", settings.CLEANING_DETREND_ORDER),
    }
    analysis_cfg = {
        "freq_min": analysis_cfg_raw.get("freq_min", settings.ANALYSIS_FREQ_MIN),
        "freq_max": analysis_cfg_raw.get("freq_max", settings.ANALYSIS_FREQ_MAX),
        "max_modes": analysis_cfg_raw.get("max_modes", settings.ANALYSIS_MAX_MODES),
        "mac_threshold": analysis_cfg_raw.get("mac_threshold", settings.ANALYSIS_MAC_THRESHOLD),
        "freq_rel_tol": analysis_cfg_raw.get("freq_rel_tol", settings.ANALYSIS_FREQ_REL_TOL),
        "cluster_tol": analysis_cfg_raw.get("cluster_tol", settings.ANALYSIS_CLUSTER_TOL),
    }

    # IMPORTANT: call /structures/ with trailing slash to avoid 307 redirect edge cases
    structures = await _get_json(f"{meta}/structures/", [])
    if not isinstance(structures, list):
        structures = []
    default_structure = system_cfg.get("monitored_structure") or settings.MONITORED_STRUCTURE

    # Sensors panel (optional)
    sensors: list[dict] = []
    sensors_error: Optional[str] = None
    if show_sensors_flag:
        # Find structure_id by name: the first entry whose name matches wins.
        structure_id = next(
            (
                s.get("structure_id")
                for s in structures
                if isinstance(s, dict) and s.get("name") == default_structure
            ),
            None,
        )

        if structure_id is None and default_structure:
            # fallback: resolve via by-name endpoint. The name is escaped so
            # characters such as "/", "?" or "#" cannot alter the request path.
            encoded_name = quote(str(default_structure), safe="")
            s_obj = await _get_json(f"{meta}/structures/by-name/{encoded_name}", None)
            if isinstance(s_obj, dict):
                structure_id = s_obj.get("structure_id")

        if structure_id is not None:
            sensors = await _get_json(f"{meta}/sensors/?structure_id={int(structure_id)}", [])
        else:
            sensors_error = "Default structure not found in DB."

    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "msg": msg,
            "edit_mode": edit_mode,
            "show_sensors": show_sensors_flag,
            "sensors": sensors,
            "sensors_error": sensors_error,
            "cleaning_cfg": cleaning_cfg,
            "analysis_cfg": analysis_cfg,
            "system_cfg": system_cfg,
            "default_structure": default_structure,
            "structures": structures,
            "meta_url": meta,
        },
    )


@app.post("/config/cleaning")
async def update_cleaning_config(
    hp: Optional[str] = Form(default=None),
    lp: Optional[str] = Form(default=None),
    downsample: Optional[str] = Form(default=None),
    filter_order: Optional[str] = Form(default=None),
    detrend_order: Optional[str] = Form(default=None),
    username: str = Depends(verify_auth),
):
    """Forward the submitted cleaning settings to the metadata service.

    Blank or missing form fields are left out of the payload; the remaining
    values are sent as the strings received from the form.

    Returns:
        A 303 redirect to the home page carrying a success or failure message.
    """
    submitted = {
        "hp": hp,
        "lp": lp,
        "downsample": downsample,
        "filter_order": filter_order,
        "detrend_order": detrend_order,
    }
    payload: Dict[str, Any] = {
        key: value
        for key, value in submitted.items()
        if value is not None and str(value).strip() != ""
    }

    meta = _meta_base()
    _, err = await _put_json(f"{meta}/config/cleaning", payload, None)
    msg = "✅ Cleaning config updated" if err is None else f"❌ Cleaning config update failed: {err}"
    # The error text comes from the upstream response; percent-encode it so
    # characters like "&" or "#" cannot cut off or alter the query string.
    return RedirectResponse(url=f"/?msg={quote(msg)}", status_code=303)


@app.post("/config/analysis")
async def update_analysis_config(
    freq_min: Optional[str] = Form(default=None),
    freq_max: Optional[str] = Form(default=None),
    max_modes: Optional[str] = Form(default=None),
    mac_threshold: Optional[str] = Form(default=None),
    freq_rel_tol: Optional[str] = Form(default=None),
    cluster_tol: Optional[str] = Form(default=None),
    username: str = Depends(verify_auth),
):
    """Forward the submitted analysis settings to the metadata service.

    Blank or missing form fields are left out of the payload; the remaining
    values are sent as the strings received from the form.

    Returns:
        A 303 redirect to the home page carrying a success or failure message.
    """
    submitted = {
        "freq_min": freq_min,
        "freq_max": freq_max,
        "max_modes": max_modes,
        "mac_threshold": mac_threshold,
        "freq_rel_tol": freq_rel_tol,
        "cluster_tol": cluster_tol,
    }
    payload: Dict[str, Any] = {
        key: value
        for key, value in submitted.items()
        if value is not None and str(value).strip() != ""
    }

    meta = _meta_base()
    _, err = await _put_json(f"{meta}/config/analysis", payload, None)
    msg = "✅ Analysis config updated" if err is None else f"❌ Analysis config update failed: {err}"
    # The error text comes from the upstream response; percent-encode it so
    # characters like "&" or "#" cannot cut off or alter the query string.
    return RedirectResponse(url=f"/?msg={quote(msg)}", status_code=303)


@app.post("/config/system")
async def update_system_config(monitored_structure: str = Form(...), username: str = Depends(verify_auth)):
    """Store *monitored_structure* as the default structure in the metadata service.

    Returns:
        A 303 redirect to the home page carrying a success or failure message.
    """
    meta = _meta_base()
    _, err = await _put_json(f"{meta}/config/system", {"monitored_structure": monitored_structure}, None)
    msg = "✅ Default structure updated" if err is None else f"❌ Default structure update failed: {err}"
    # Percent-encode the message: upstream error text may contain "&" or "#",
    # which would otherwise truncate or corrupt the query string.
    return RedirectResponse(url=f"/?msg={quote(msg)}", status_code=303)


@app.post("/structures/add")
async def add_structure(
    name: str = Form(...),
    type: Optional[str] = Form(default=None),
    location: Optional[str] = Form(default=None),
    latitude: Optional[str] = Form(default=None),
    longitude: Optional[str] = Form(default=None),
    description: Optional[str] = Form(default=None),
    username: str = Depends(verify_auth),
):
    """Create a structure in the metadata service from the submitted form.

    Optional text fields are stripped and omitted when blank. ``latitude`` and
    ``longitude`` must parse as floats; otherwise nothing is sent and the user
    is redirected with an error message.

    Returns:
        A 303 redirect to the home page carrying a success or failure message.
    """
    meta = _meta_base()
    payload: Dict[str, Any] = {"name": name}

    for key, text in (("type", type), ("location", location)):
        if text and text.strip():
            payload[key] = text.strip()

    # Coordinates are validated in the original order: latitude, then longitude.
    for key, text in (("latitude", latitude), ("longitude", longitude)):
        if text and text.strip():
            try:
                payload[key] = float(text)
            except ValueError:
                return RedirectResponse(url=f"/?msg={quote(f'❌ Invalid {key}')}", status_code=303)

    if description and description.strip():
        payload["description"] = description.strip()

    _, err = await _post_json(f"{meta}/structures/", payload, None)
    msg = "✅ Structure added" if err is None else f"❌ Add structure failed: {err}"
    # Percent-encode the message: upstream error text may contain "&" or "#",
    # which would otherwise truncate or corrupt the query string.
    return RedirectResponse(url=f"/?msg={quote(msg)}", status_code=303)


@app.post("/structures/delete")
async def delete_structure(structure_name: str = Form(...), username: str = Depends(verify_auth)):
    """Delete the structure called *structure_name* via the metadata service.

    The name is first resolved to a ``structure_id`` through the by-name
    endpoint; the delete itself is issued against that id.

    Returns:
        A 303 redirect to the home page carrying a success or failure message.
    """
    meta = _meta_base()

    # Resolve to ID. The name is user input: escape it fully (including "/")
    # so it stays a single path segment and cannot redirect the lookup to a
    # different endpoint.
    encoded_name = quote(structure_name, safe="")
    obj = await _get_json(f"{meta}/structures/by-name/{encoded_name}", None)
    if not isinstance(obj, dict) or "structure_id" not in obj:
        return RedirectResponse(url=f"/?msg={quote('❌ Structure not found')}", status_code=303)

    ok, err = await _delete(f"{meta}/structures/{obj['structure_id']}")
    msg = "✅ Structure deleted" if ok else f"❌ Delete failed: {err}"
    # Percent-encode the message: upstream error text may contain "&" or "#".
    return RedirectResponse(url=f"/?msg={quote(msg)}", status_code=303)




# -----------------------------
# Manual upload + local outputs (no DB writes)
# -----------------------------
@dataclass
class ManualArtifact:
    """Paths of the downloadable files produced by one manual run."""

    # Path of the cleaned-signal text file; None when none was produced.
    cleaned_path: Optional[str] = None
    # Path of the analysis JSON file; None when the run did not include analysis.
    analysis_json_path: Optional[str] = None


@dataclass
class ManualUpload:
    """A raw upload kept on disk together with its parsed channel names."""

    # Path of the temporary file holding the uploaded raw data.
    raw_path: str
    # Channel labels reported by the parser for this file.
    channel_names: list[str]


# In-process registries keyed by upload token. Nothing in this file removes
# entries, so they live for the lifetime of the process.
# NOTE(review): entries are not shared between worker processes — confirm the
# service runs as a single worker.
_MANUAL_ARTIFACTS: Dict[str, ManualArtifact] = {}
_MANUAL_UPLOADS: Dict[str, ManualUpload] = {}



def _moving_average(x: np.ndarray, win: int) -> np.ndarray:
    win = int(max(1, win))
    kernel = np.ones(win) / win
    return np.convolve(x, kernel, mode="same")


def _fig_to_html(fig: "go.Figure") -> str:
    # include_plotlyjs="cdn" keeps HTML small; browser loads JS from CDN.
    return fig.to_html(full_html=False, include_plotlyjs="cdn")


def _make_drift_plot_html(
    t_raw: np.ndarray,
    X_raw: np.ndarray,
    t_cln: np.ndarray,
    X_cln: np.ndarray,
    fs_raw: float,
    fs_cln: float,
    labels: list[str],
    max_lines: int = 12,
) -> tuple[str, str]:
    """Build the raw-vs-cleaned baseline/drift plot.

    Each plotted channel contributes two traces: the moving average of the raw
    column and of the cleaned column, using a window of ``fs * 10`` samples.

    Args:
        t_raw: Time axis for the raw traces.
        X_raw: Raw data; columns are indexed as channels.
        t_cln: Time axis for the cleaned traces.
        X_cln: Cleaned data; columns are indexed as channels.
        fs_raw: Sampling rate used to size the raw window.
        fs_cln: Sampling rate used to size the cleaned window.
        labels: Trace labels, indexed like the columns.
        max_lines: Maximum number of channels to plot.

    Returns:
        ``(html, note)``. ``note`` is non-empty only when some channels were
        left out because there are more than *max_lines* of them.
    """
    n_ch = int(X_raw.shape[1])
    n_plot = min(n_ch, int(max_lines))
    note = ""
    if n_ch > n_plot:
        note = f"Showing first {n_plot} of {n_ch} channels."

    # The window sizes depend only on the sampling rates, so they are computed
    # once instead of once per channel.
    win_r = int(fs_raw * 10)
    win_c = int(fs_cln * 10)

    fig = go.Figure()
    for i in range(n_plot):
        tr = _moving_average(X_raw[:, i], win_r)
        tc = _moving_average(X_cln[:, i], win_c)
        fig.add_trace(go.Scatter(x=t_raw, y=tr, mode="lines", name=f"raw {labels[i]}"))
        fig.add_trace(go.Scatter(x=t_cln, y=tc, mode="lines", name=f"clean {labels[i]}"))
    fig.update_layout(
        title="Baseline/Drift comparison (moving average, 10s)",
        xaxis_title="time [s]",
        yaxis_title="moving avg (10s)",
        legend_title="Series",
        height=520,
        margin=dict(l=40, r=20, t=60, b=40),
    )
    return _fig_to_html(fig), note


def _make_psd_plot_html(
    X_raw: np.ndarray,
    X_cln: np.ndarray,
    fs_raw: float,
    fs_cln: float,
    hp: float,
    lp: float,
    labels: list[str],
    max_lines: int = 12,
) -> tuple[str, str]:
    """Welch PSD comparison for up to max_lines channels. Return (html, note).

    Args:
        X_raw: Raw data; columns are indexed as channels.
        X_cln: Cleaned data; columns are indexed as channels.
        fs_raw: Sampling rate passed to Welch for the raw columns.
        fs_cln: Sampling rate passed to Welch for the cleaned columns.
        hp: Frequency marked with a dashed ``hp`` line.
        lp: Frequency marked with a dashed ``lp`` line.
        labels: Trace labels, indexed like the columns.
        max_lines: Maximum number of channels to plot.

    Returns:
        ``(html, note)``. ``note`` is non-empty only when some channels were
        left out because there are more than *max_lines* of them.
    """
    n_ch = int(X_raw.shape[1])
    n_plot = min(n_ch, int(max_lines))
    note = ""
    if n_ch > n_plot:
        note = f"Showing first {n_plot} of {n_ch} channels."

    fig = go.Figure()
    # Highest frequency bin seen in any plotted spectrum; used for the x range.
    freq_upper = 0.0
    if n_plot > 0:
        # Every column has the same length, so the segment sizes are loop-invariant.
        nperseg_r = min(X_raw.shape[0], max(256, int(fs_raw * 8)))
        nperseg_c = min(X_cln.shape[0], max(256, int(fs_cln * 8)))
        for i in range(n_plot):
            f_r, P_r = welch(X_raw[:, i], fs=fs_raw, nperseg=nperseg_r, detrend="constant")
            f_c, P_c = welch(X_cln[:, i], fs=fs_cln, nperseg=nperseg_c, detrend="constant")

            fig.add_trace(go.Scatter(x=f_r, y=P_r, mode="lines", name=f"raw PSD {labels[i]}"))
            fig.add_trace(go.Scatter(x=f_c, y=P_c, mode="lines", name=f"clean PSD {labels[i]}"))
            freq_upper = max(freq_upper, float(np.max(f_r)), float(np.max(f_c)))

    fig.add_vline(x=float(hp), line_dash="dash", annotation_text=f"hp={hp}Hz", annotation_position="top left")
    fig.add_vline(x=float(lp), line_dash="dash", annotation_text=f"lp={lp}Hz", annotation_position="top right")

    fig.update_layout(
        title="Welch PSD comparison (log scale)",
        xaxis_title="frequency [Hz]",
        yaxis_title="PSD",
        yaxis_type="log",
        height=520,
        margin=dict(l=40, r=20, t=60, b=40),
    )
    # With no channels there is no spectrum to derive a range from; previously
    # this line referenced loop variables that were never bound in that case.
    if n_plot > 0:
        fig.update_xaxes(range=[0, min(50, freq_upper)])
    return _fig_to_html(fig), note


def _pick_channel_indices(all_labels: list[str], requested: list[str]) -> list[int]:
    if not requested:
        return []
    idx = []
    label_to_idx = {str(n): i for i, n in enumerate(all_labels)}
    for name in requested:
        if name in label_to_idx:
            idx.append(label_to_idx[name])
    seen = set()
    out = []
    for i in idx:
        if i not in seen:
            seen.add(i)
            out.append(i)
    return out


@app.get("/manual", response_class=HTMLResponse)
async def manual_page(request: Request, token: Optional[str] = None, msg: Optional[str] = None, username: str = Depends(verify_auth)):
    """Render the manual-upload page, restoring state for *token* when known."""
    art = None
    upload = None
    if token:
        art = _MANUAL_ARTIFACTS.get(token)
        upload = _MANUAL_UPLOADS.get(token)

    context = {
        "request": request,
        "token": token,
        "msg": msg,
        "art": art,
        # Channel names are only known once an upload is registered for the token.
        "channels": upload.channel_names if upload else None,
        "selected_channels": request.query_params.get("channels", ""),
    }
    return templates.TemplateResponse("manual.html", context)


@app.post("/manual/inspect", response_class=HTMLResponse)
async def manual_inspect(request: Request, file: UploadFile = File(...), username: str = Depends(verify_auth)):
    """Store an uploaded file, parse its channel names and register it under a new token.

    On success the browser is redirected to ``/manual`` with the token. If the
    file cannot be parsed, the temporary copy is removed and the page is
    rendered with the error message.
    """
    extension = os.path.splitext(file.filename or "upload.txt")[1]
    if not extension:
        extension = ".txt"
    with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as handle:
        raw_path = handle.name
        handle.write(await file.read())

    token = secrets.token_urlsafe(12)
    try:
        channel_names = list(parse_txt_file(raw_path).get("sensor_cols"))
        _MANUAL_UPLOADS[token] = ManualUpload(raw_path=raw_path, channel_names=channel_names)
        return RedirectResponse(url=f"/manual?token={token}&msg=✅ Channels loaded", status_code=303)
    except Exception as exc:
        # The upload is unusable: drop the temporary copy (best effort).
        try:
            os.remove(raw_path)
        except Exception:
            pass
        context = {
            "request": request,
            "token": None,
            "msg": f"❌ Failed to read file: {exc}",
            "art": None,
            "channels": None,
            "selected_channels": "",
        }
        return templates.TemplateResponse("manual.html", context)


def _cfg_number(cfg: Any, key: str, default: Any) -> float:
    """Return ``cfg[key]`` as a float, using *default* only when it is missing or blank.

    Unlike ``cfg.get(key) or default`` this keeps an explicitly stored ``0``.
    """
    value = cfg.get(key) if isinstance(cfg, dict) else None
    if value is None or (isinstance(value, str) and not value.strip()):
        value = default
    return float(value)


@app.post("/manual/run", response_class=HTMLResponse)
async def manual_run(
    request: Request,
    token: Optional[str] = Form(default=None),
    file: UploadFile | None = File(default=None),
    action: str = Form("clean_and_analyze"),
    channels: Optional[str] = Form(default=None),
    username: str = Depends(verify_auth),
):
    """Clean (and optionally analyze) an uploaded file and render the results.

    The raw data comes from the upload registered under *token*, or from *file*
    when no such upload exists. The cleaned signal is written to a text file in
    the temp directory. With ``action == "clean_and_analyze"`` the identified
    modes are also written to a JSON file. Both are registered for download.

    Args:
        request: Incoming request, forwarded to the template.
        token: Token of a previously registered upload.
        file: Raw file, used when *token* does not match a registered upload.
        action: ``"clean_and_analyze"`` additionally runs mode identification.
        channels: Comma-separated channel names to plot; blank plots all.
        username: Authenticated user (enforces authentication only).

    Returns:
        The rendered ``manual.html`` response, or a 303 redirect to ``/manual``
        when neither a registered upload nor a file is available.
    """
    meta = _meta_base()
    cleaning_cfg = await _get_json(f"{meta}/config/cleaning", {})
    analysis_cfg = await _get_json(f"{meta}/config/analysis", {})

    hp = _cfg_number(cleaning_cfg, "hp", settings.CLEANING_HP)
    lp = _cfg_number(cleaning_cfg, "lp", settings.CLEANING_LP)
    downsample = int(_cfg_number(cleaning_cfg, "downsample", settings.CLEANING_DOWNSAMPLE))
    filter_order = int(_cfg_number(cleaning_cfg, "filter_order", settings.CLEANING_FILTER_ORDER))
    detrend_order = int(_cfg_number(cleaning_cfg, "detrend_order", settings.CLEANING_DETREND_ORDER))

    freq_min = _cfg_number(analysis_cfg, "freq_min", settings.ANALYSIS_FREQ_MIN)
    freq_max = _cfg_number(analysis_cfg, "freq_max", settings.ANALYSIS_FREQ_MAX)
    max_modes = int(_cfg_number(analysis_cfg, "max_modes", settings.ANALYSIS_MAX_MODES))

    upload = _MANUAL_UPLOADS.get(token) if token else None
    # A file that belongs to a registered upload must outlive this request.
    keep_raw = upload is not None
    if upload is not None:
        raw_path = upload.raw_path
    elif file is not None:
        suffix = os.path.splitext(file.filename or "upload.txt")[1] or ".txt"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            raw_path = tmp.name
            tmp.write(await file.read())
        # The token ends up in file names below, so an unknown client-supplied
        # value is never reused; a fresh one is always generated here.
        token = secrets.token_urlsafe(12)
    else:
        return RedirectResponse(url="/manual?msg=❌ Please upload a file first (Inspect channels)", status_code=303)

    requested = [c.strip() for c in channels.split(",") if c.strip()] if channels else []
    step = max(1, downsample)

    try:
        parsed = parse_txt_file(raw_path)
        X_raw_full = parsed["data"].astype(np.float64)
        t_raw = parsed["t_rel"].astype(np.float64)
        fs_raw = float(parsed["sampling_hz"])
        labels_full = parsed.get("sensor_cols") or [f"col{i+1}" for i in range(X_raw_full.shape[1])]

        cleaned_full, fs_cln = clean_signal_matrix(
            X_raw_full,
            fc=fs_raw,
            ss=step,
            hp=hp,
            lp=lp,
            filter_order=filter_order,
            detrend_order=detrend_order,
        )
        # Time axis for the cleaned rows: every step-th raw timestamp, trimmed
        # to the number of cleaned rows.
        t_cln = t_raw[::step][: cleaned_full.shape[0]]

        cleaned_path = os.path.join(tempfile.gettempdir(), f"cleaned_{token}.txt")
        dt_new = (t_cln[1] - t_cln[0]) if len(t_cln) > 1 else (1.0 / float(fs_cln))
        header = "\t".join(["tempo"] + [str(c) for c in labels_full])
        lines = [
            "t0\t" + parsed["start_time"].astimezone(timezone.utc).strftime("%Y/%m/%d %H:%M:%S"),
            f"dt\t{dt_new:.6f}",
            header,
            "",
        ]
        for ti, row in zip(t_cln, cleaned_full):
            vals = "\t".join([f"{v:.8f}" for v in row.tolist()])
            lines.append(f"{ti:.6f}\t{vals}")
        with open(cleaned_path, "w", encoding="utf-8") as fh:
            fh.write("\n".join(lines))

        # Plots use only the requested channels; with no usable request, all of them.
        idx = _pick_channel_indices(list(labels_full), requested)
        if idx:
            X_raw = X_raw_full[:, idx]
            X_cln = cleaned_full[:, idx]
            labels = [labels_full[i] for i in idx]
        else:
            X_raw = X_raw_full
            X_cln = cleaned_full
            labels = list(labels_full)

        drift_html, drift_note = _make_drift_plot_html(t_raw, X_raw, t_cln, X_cln, fs_raw, float(fs_cln), labels=labels)
        psd_html, psd_note = _make_psd_plot_html(X_raw, X_cln, fs_raw, float(fs_cln), hp=hp, lp=lp, labels=labels)

        modes = []
        analysis_json_path = None
        if action == "clean_and_analyze":
            # The analysis always runs on all cleaned channels, not the plotted subset.
            est = identify_modes_fdd(cleaned_full, fs=float(fs_cln), freq_min=freq_min, freq_max=freq_max, max_modes=max_modes)
            modes = [
                {
                    "frequency_hz": float(m.frequency_hz),
                    "quality": float(m.quality),
                    # Coerced to a plain float so json.dumps accepts it.
                    "damping_ratio": None if m.damping_ratio is None else float(m.damping_ratio),
                }
                for m in est
            ]
            analysis_json_path = os.path.join(tempfile.gettempdir(), f"analysis_{token}.json")
            with open(analysis_json_path, "w", encoding="utf-8") as fh:
                fh.write(json.dumps({"modes": modes}, indent=2))

        if upload is None:
            # Register the fresh upload only after everything succeeded, and
            # keep its raw file so a later run with this token can read it.
            _MANUAL_UPLOADS[token] = ManualUpload(raw_path=raw_path, channel_names=list(labels_full))
            keep_raw = True

        _MANUAL_ARTIFACTS[token] = ManualArtifact(cleaned_path=cleaned_path, analysis_json_path=analysis_json_path)
        msg = "✅ Manual run complete. Downloads are available below."
        return templates.TemplateResponse(
            "manual.html",
            {
                "request": request,
                "token": token,
                "msg": msg,
                "art": _MANUAL_ARTIFACTS[token],
                "channels": _MANUAL_UPLOADS[token].channel_names,
                "selected_channels": channels or "",
                "drift_html": drift_html,
                "psd_html": psd_html,
                "drift_note": drift_note,
                "psd_note": psd_note,
                "modes": modes,
            },
        )
    finally:
        # Remove the raw file only when nothing references it, i.e. a fresh
        # upload whose processing failed before it was registered.
        if not keep_raw and raw_path:
            try:
                os.remove(raw_path)
            except OSError:
                pass

@app.get("/manual/download/{token}/{kind}")
def manual_download(token: str, kind: str, username: str = Depends(verify_auth)):
    """Serve the ``cleaned`` or ``analysis`` file registered for *token*.

    Raises:
        HTTPException: 404 for an unknown token or a missing file, 400 for an
            unknown *kind*.
    """
    artifact = _MANUAL_ARTIFACTS.get(token)
    if artifact is None:
        raise HTTPException(status_code=404, detail="Unknown token")

    if kind == "cleaned":
        path = artifact.cleaned_path
        missing_detail = "Missing cleaned file"
        download_name = f"cleaned_{token}.txt"
    elif kind == "analysis":
        path = artifact.analysis_json_path
        missing_detail = "Missing analysis result"
        download_name = f"analysis_{token}.json"
    else:
        raise HTTPException(status_code=400, detail="Unknown kind")

    # The path may be unset, or the file may no longer be on disk.
    if not path or not os.path.exists(path):
        raise HTTPException(status_code=404, detail=missing_detail)
    return FileResponse(path, filename=download_name)
