# English comments only
from __future__ import annotations

import json
import os
import threading
import time
from typing import Optional

import paho.mqtt.client as mqtt
from redis import Redis

from .core.config import HMConfig
from .notify.telegram import notify
from .metrics import inc, mark
# English comments only
def _parse_aes_key(raw: str) -> Optional[bytes]:
    """Decode a 32-byte AES key given as 64 hex characters or standard base64.

    Surrounding whitespace is ignored and missing base64 padding is tolerated.
    Returns None when *raw* is blank or does not decode to exactly 32 bytes
    under either encoding.
    """
    import base64

    raw = raw.strip()
    if not raw:
        return None
    # Try hex first: hex digits are a subset of the base64 alphabet, so a
    # 64-char hex key would also base64-decode "successfully" -- to 48 bytes.
    try:
        key = bytes.fromhex(raw)
    except ValueError:
        key = b""
    if len(key) == 32:
        return key
    try:
        # validate=True rejects characters outside the base64 alphabet instead
        # of silently dropping them (binascii.Error subclasses ValueError).
        key = base64.b64decode(raw + "=" * (-len(raw) % 4), validate=True)
    except ValueError:
        return None
    return key if len(key) == 32 else None


def process_vitals(data: dict) -> dict:
    """
    Ingest a single vitals record and append it to a Redis stream.

    - Build plaintext payload {metric,value,unit,ts,source}: ``metric`` is
      lower-cased, ``value`` goes through ``float()`` (so non-numeric input
      raises), a missing/falsy ``ts`` becomes the current time and a missing
      ``source`` becomes "api".
    - Flag the record as critical with env-overridable thresholds:
      hr >= HR_HIGH (120), spo2 <= SPO2_LOW (90),
      temp/temperature >= TEMP_HIGH (38.5).
    - If HM_AES_KEY decodes to a 32-byte key (64 hex chars or base64), encrypt
      with AES-GCM and store envelope {"enc":"aesgcm","iv":b64,"ct":b64};
      otherwise store {"enc":"none","json":raw_json}. A set-but-unusable key
      falls back to plaintext and prints a warning.
    - XADD to stream (ENV VITALS_STREAM or "vitals"), field "blob" = json.dumps(envelope)
    - On a critical record, send a best-effort high-priority notification.

    Returns {"ok": True, "ts": ts, "critical": bool}.
    """
    import base64
    import secrets

    from .core.redis import get_redis

    # Normalize payload
    now = int(time.time())
    metric = str(data.get("metric", "")).lower()
    value = float(data.get("value", 0))
    unit = str(data.get("unit", ""))
    ts = int(data.get("ts") or now)
    source = str(data.get("source") or "api")
    payload = {"metric": metric, "value": value, "unit": unit, "ts": ts, "source": source}
    raw_json = json.dumps(payload, separators=(",", ":"))

    # Decide critical by simple thresholds (ENV overrides)
    hr_high = float(os.getenv("HR_HIGH", "120"))
    spo2_low = float(os.getenv("SPO2_LOW", "90"))
    temp_high = float(os.getenv("TEMP_HIGH", "38.5"))
    critical = (
        (metric == "hr" and value >= hr_high)
        or (metric == "spo2" and value <= spo2_low)
        or (metric in ("temp", "temperature") and value >= temp_high)
    )

    # Build envelope: encrypt only when a usable key is configured.
    key_raw = os.getenv("HM_AES_KEY", "")
    key = _parse_aes_key(key_raw)
    if key is None:
        if key_raw.strip():
            print("[hm] HM_AES_KEY is not a valid 32-byte hex/base64 key; storing vitals unencrypted")
        envelope = {"enc": "none", "json": raw_json}
    else:
        # Imported lazily so plaintext mode works without `cryptography`.
        from cryptography.hazmat.primitives.ciphers.aead import AESGCM

        iv = secrets.token_bytes(12)  # fresh 96-bit nonce per record
        ct = AESGCM(key).encrypt(iv, raw_json.encode("utf-8"), None)
        envelope = {
            "enc": "aesgcm",
            "iv": base64.b64encode(iv).decode(),
            "ct": base64.b64encode(ct).decode(),
        }

    # Write to Redis Stream
    r = get_redis()
    stream = os.getenv("VITALS_STREAM", "vitals")
    r.xadd(stream, {"blob": json.dumps(envelope, separators=(",", ":"))})
    # NOTE(review): unlike the MQTT path, this does not refresh
    # hm:last_measure_ts for hr readings -- confirm whether API-ingested hr
    # should suppress the no-measure reminder.

    # Optional notify on critical; remember whether notify reported success.
    alert_sent = False
    if critical:
        try:
            alert_sent = bool(
                notify(f"[HM] critical {metric}={value}{unit} from {source}", priority="high")
            )
        except Exception as e:
            print(f"[hm] notify failed: {e}")

    # Update metrics (best-effort). alerts_sent only counts notifications that
    # were actually sent; the 0-increment keeps the original per-record call.
    try:
        inc("alerts_sent", 1 if alert_sent else 0)
        mark("last_msg_ts", ts)
    except Exception as e:
        print(f"[hm] metrics update failed: {e}")

    return {"ok": True, "ts": ts, "critical": critical}

# Subscriber runtime state shared by start()/stop()/status() and the worker thread.
_thread: Optional[threading.Thread] = None  # background worker thread handle
_running = False  # set by start(), cleared by stop()
_last_connect_rc: Optional[int] = None  # 0 after a connect call, -1 after a loop error
_last_msg_ts: Optional[int] = None  # ts of the most recent parsed MQTT message


def _redis_set_cooldown(r: Redis, key: str, ttl_sec: int) -> bool:
    """Atomically claim a cooldown window; return True if this call claimed it.

    Issues ``SET key "1" EX ttl NX``: the first caller within the window gets
    True, later callers get False until the key expires. A non-positive
    ``ttl_sec`` means "no cooldown" and returns True without touching Redis,
    because Redis rejects ``EX`` values <= 0 with an error.
    """
    ttl = int(ttl_sec)
    if ttl <= 0:
        return True
    return r.set(name=key, value="1", ex=ttl, nx=True) is True


def _cooldown_key(prefix: str) -> str:
    """Map a cooldown name (e.g. "hr_high") to its Redis key under "hm:cd:"."""
    return "hm:cd:{}".format(prefix)


def _update_last_measure(r: Redis, ts: int) -> None:
    """Store *ts* as the latest measurement time; the key expires after one day."""
    one_day_sec = 24 * 3600
    r.set("hm:last_measure_ts", ts, ex=one_day_sec)


def _get_last_measure(r: Redis) -> Optional[int]:
    """Read the stored last-measurement timestamp, or None if absent/empty."""
    stored = r.get("hm:last_measure_ts")
    if not stored:
        return None
    return int(stored)


def _connect_client() -> mqtt.Client:
    """Create an MQTT client and open a connection to the broker.

    Broker address comes from MQTT_HOST (default "mosquitto") and MQTT_PORT
    (default 1883); keepalive is MQTT_KEEPALIVE seconds (default 30).
    connect() errors (and a non-integer port/keepalive) propagate to the caller.
    """
    import secrets

    host = os.getenv("MQTT_HOST", "mosquitto")
    port = int(os.getenv("MQTT_PORT", "1883"))
    keepalive = int(os.getenv("MQTT_KEEPALIVE", "30"))
    # A random suffix keeps ids unique across replicas started in the same
    # second: per the MQTT spec, a second connection with an already-connected
    # client id makes the broker disconnect the existing client.
    client_id = f"hm-sub-{int(time.time())}-{secrets.token_hex(3)}"
    c = mqtt.Client(client_id=client_id)
    c.connect(host, port, keepalive=keepalive)
    return c


def _parse_msg_ts(raw: object) -> int:
    """Coerce a message "ts" field to integer epoch seconds.

    Accepts ints, floats and numeric strings ("1700000000", "1700000000.5");
    a missing or unparseable value falls back to the current time.
    """
    if raw is not None:
        try:
            return int(raw)
        except (TypeError, ValueError, OverflowError):
            pass
        try:
            return int(float(raw))
        except (TypeError, ValueError, OverflowError):
            pass
    return int(time.time())


def _send_alert(
    r: Redis,
    *,
    cooldown_name: str,
    ttl_sec: int,
    text: str,
    priority: str,
    counters: tuple[str, ...],
    mark_name: str,
    ts: int,
) -> None:
    """Notify with *text* unless the named cooldown is active; update metrics on success."""
    key = _cooldown_key(cooldown_name)
    if not _redis_set_cooldown(r, key, ttl_sec):
        return
    try:
        ok = notify(text, priority=priority)
    except Exception as e:
        print(f"[hm] notify failed: {e}")
        # Release the cooldown so the next out-of-range reading retries,
        # instead of the alert being suppressed for the whole window.
        r.delete(key)
        return
    if ok:
        for name in counters:
            inc(name, 1)
        mark(mark_name, ts)


def _on_message(cfg: HMConfig, r: Redis, _client: mqtt.Client, msg: mqtt.MQTTMessage) -> None:
    """Handle one MQTT vitals message and fire threshold alerts.

    The payload is expected to be a UTF-8 JSON object with "metric" and
    "value", plus optional "ts", "unit" and "source". Anything that cannot be
    parsed is dropped, because an exception escaping this callback would
    propagate out of the MQTT network loop.

    Alerts (each guarded by a Redis cooldown key):
    - hr >= cfg.HR_HIGH: high priority, counted as an emergency.
    - spo2 <= cfg.SPO2_LOW: high priority, counted as an emergency.
    - temp/temperature >= cfg.TEMP_HIGH: normal priority, counted as an alert.
    Every hr reading with a numeric value also refreshes the last-measure
    timestamp.
    """
    global _last_msg_ts
    try:
        p = json.loads(msg.payload.decode("utf-8"))
    except (ValueError, RecursionError):
        # Not UTF-8 or not JSON (UnicodeDecodeError/JSONDecodeError are ValueErrors).
        return
    if not isinstance(p, dict):
        return
    ts = _parse_msg_ts(p.get("ts"))
    metric = str(p.get("metric", "")).lower()
    _last_msg_ts = ts

    if metric not in ("hr", "spo2", "temp", "temperature"):
        return
    try:
        val = float(p.get("value"))
    except (TypeError, ValueError, OverflowError):
        return
    source = p.get("source")
    suffix = f" source={source}" if source else ""

    if metric == "hr":
        _update_last_measure(r, ts)
        if val >= cfg.HR_HIGH:
            _send_alert(
                r,
                cooldown_name="hr_high",
                ttl_sec=cfg.HR_ALERT_COOLDOWN_MIN * 60,
                text=f"Health alert: hr={val:.1f} bpm{suffix}",
                priority="high",
                counters=("emergencies_sent",),
                mark_name="last_emergency_ts",
                ts=ts,
            )
    elif metric == "spo2":
        if val <= cfg.SPO2_LOW:
            unit = str(p.get("unit") or "%")
            _send_alert(
                r,
                cooldown_name="spo2_low",
                ttl_sec=cfg.SPO2_ALERT_COOLDOWN_MIN * 60,
                text=f"Health alert: spo2={val:.0f} {unit}{suffix}",
                priority="high",
                counters=("emergencies_sent", "spo2_events"),
                mark_name="last_emergency_ts",
                ts=ts,
            )
    elif val >= cfg.TEMP_HIGH:  # metric is "temp" or "temperature"
        unit = str(p.get("unit") or "C")
        _send_alert(
            r,
            cooldown_name="temp_high",
            ttl_sec=cfg.TEMP_ALERT_COOLDOWN_MIN * 60,
            text=f"Health alert: temp={val:.1f} °{unit}{suffix}",
            priority="normal",
            counters=("alerts_sent", "temp_events"),
            mark_name="last_alert_ts",
            ts=ts,
        )


def _worker(cfg: HMConfig, r: Redis) -> None:
    """Subscriber thread body: connect, subscribe to "vitals/ingest", pump messages.

    Runs until the module-level _running flag is cleared. The network loop is
    driven with client.loop(timeout=1.0) instead of loop_forever(), so the flag
    is re-checked about once a second and stop() actually ends the thread.

    Any error, including a lost connection, records _last_connect_rc = -1 and
    waits 2 s. The loop then reconnects with a fresh client, which subscribes
    again on the new connection.
    """
    global _last_connect_rc
    while _running:
        client = None
        try:
            client = _connect_client()
            _last_connect_rc = 0
            client.on_message = lambda c, _userdata, m: _on_message(cfg, r, c, m)
            client.subscribe("vitals/ingest", qos=0)
            while _running:
                rc = client.loop(timeout=1.0)
                if rc != mqtt.MQTT_ERR_SUCCESS:
                    raise RuntimeError(f"mqtt loop returned rc={rc}: {mqtt.error_string(rc)}")
        except Exception as e:
            _last_connect_rc = -1
            print(f"[hm] mqtt loop error: {e}")
            time.sleep(2.0)
        finally:
            if client is not None:
                try:
                    client.disconnect()
                except Exception:
                    # Best-effort: the connection may already be gone.
                    pass


def start(cfg: HMConfig, r: Redis) -> None:
    """Launch the MQTT subscriber thread unless the running flag is already set.

    Sets the module-level running flag, starts a daemon thread named
    "hm-subscriber" that runs _worker(cfg, r), and prints a start message.
    Calling it while the flag is set does nothing.
    """
    global _running, _thread
    if not _running:
        _running = True
        worker = threading.Thread(
            target=_worker,
            args=(cfg, r),
            name="hm-subscriber",
            daemon=True,
        )
        _thread = worker
        worker.start()
        print("[subscriber] started")


def stop() -> None:
    """Ask the subscriber worker to exit by clearing the module-level running flag.

    Only the flag is changed here; the thread is not joined and no MQTT client
    is disconnected. The worker loop is responsible for re-checking the flag.
    """
    global _running
    _running = False


def status() -> dict:
    """Return a snapshot of the subscriber's module-level state.

    Keys: "running" (flag managed by start/stop), "last_connect_rc" (0 after a
    successful connect call, -1 after a worker-loop error, None before any
    attempt) and "last_msg_ts" (ts of the most recent parsed message, or None).
    """
    snapshot = dict(
        running=_running,
        last_connect_rc=_last_connect_rc,
        last_msg_ts=_last_msg_ts,
    )
    return snapshot


def check_no_measure_and_remind(cfg: HMConfig, r: Redis) -> Optional[str]:
    """Send a reminder when no hr measurement has been recorded recently.

    A reminder is due when the stored last-measure timestamp is missing or
    at least cfg.NO_MEASURE_WINDOW_MIN minutes old. Reminders are rate-limited
    by a Redis cooldown of cfg.REMINDER_COOLDOWN_MIN minutes.

    Returns "reminded(no-measure)" when a reminder was sent to notify without
    an exception. Returns None when no reminder is due, when the cooldown is
    active, or when notify raised. In that last case the cooldown is released
    so a later call can retry.
    """
    now = int(time.time())
    last = _get_last_measure(r)
    window_sec = cfg.NO_MEASURE_WINDOW_MIN * 60
    if last is not None and now - last < window_sec:
        return None

    cd_key = _cooldown_key("no_measure")
    if not _redis_set_cooldown(r, cd_key, cfg.REMINDER_COOLDOWN_MIN * 60):
        return None
    try:
        ok = notify(
            f"Reminder: no hr measurement in last {cfg.NO_MEASURE_WINDOW_MIN} minutes",
            priority="normal",
        )
    except Exception as e:
        print(f"[hm] notify failed: {e}")
        # Don't let a failed send silence reminders for the whole cooldown.
        r.delete(cd_key)
        return None
    if ok:
        inc("alerts_sent", 1)
        mark("last_alert_ts", now)
    return "reminded(no-measure)"
