# English comments only
from __future__ import annotations

import json
import os
import threading
import time
from typing import Optional

import paho.mqtt.client as mqtt
from redis import Redis

from .core.config import HMConfig
from .notify.telegram import notify
from .metrics import inc, mark
from .core.redis import get_redis  # <<< expose at module level for tests to monkeypatch


def process_vitals(data: dict) -> dict:
    """
    Ingest a single vitals record and append it to the Redis vitals stream.

    - Build plaintext payload {metric,value,unit,ts,source}
    - If HM_AES_KEY is set and decodes (hex or base64) to exactly 32 bytes, encrypt the
      payload with AES-GCM and store envelope {"enc":"aesgcm","iv","ct"}; otherwise
      store {"enc":"none","json"}
    - XADD to stream (ENV VITALS_STREAM or "vitals"), fields include explicit "ts" and "blob"
    - Send a high-priority notification when the value crosses HR_HIGH / SPO2_LOW /
      TEMP_HIGH (read from the environment, defaults 120 / 90 / 38.5)
    - Return {"ok": True, "ts": ts, "critical": bool}

    Notification and metrics failures are reported on stdout and never abort the ingest.
    Conversion errors for "value"/"ts" and errors from get_redis()/xadd propagate to
    the caller.
    """
    import base64
    import secrets

    now = int(time.time())
    metric = str(data.get("metric", "")).lower()
    value = float(data.get("value", 0))
    unit = str(data.get("unit", ""))
    ts = int(data.get("ts") or now)
    source = str(data.get("source") or "api")
    payload = {"metric": metric, "value": value, "unit": unit, "ts": ts, "source": source}
    plain = json.dumps(payload, separators=(",", ":"))

    hr_high = float(os.getenv("HR_HIGH", "120"))
    spo2_low = float(os.getenv("SPO2_LOW", "90"))
    temp_high = float(os.getenv("TEMP_HIGH", "38.5"))
    critical = (
        (metric == "hr" and value >= hr_high)
        or (metric == "spo2" and value <= spo2_low)
        or (metric in ("temp", "temperature") and value >= temp_high)
    )

    key = b""
    key_raw = os.getenv("HM_AES_KEY", "").strip()
    if key_raw:
        # Hex has to be tried first: every hex digit is also a valid base64 character,
        # so a 64-char hex key base64-decodes without error to 48 bytes and would be
        # rejected as "wrong length" before the hex interpretation is ever attempted.
        if len(key_raw) == 64:
            try:
                key = bytes.fromhex(key_raw)
            except ValueError:
                key = b""
        if len(key) != 32:
            try:
                key = base64.b64decode(key_raw)
            except ValueError:  # binascii.Error is a ValueError subclass
                key = b""
        if len(key) != 32:
            print("[hm] HM_AES_KEY is not a 32-byte hex/base64 key; storing plaintext")

    if len(key) == 32:
        # Imported lazily so plaintext mode does not require the cryptography package.
        from cryptography.hazmat.primitives.ciphers.aead import AESGCM

        iv = secrets.token_bytes(12)
        ct = AESGCM(key).encrypt(iv, plain.encode("utf-8"), None)
        envelope = {
            "enc": "aesgcm",
            "iv": base64.b64encode(iv).decode(),
            "ct": base64.b64encode(ct).decode(),
        }
    else:
        envelope = {"enc": "none", "json": plain}

    r = get_redis()
    stream = os.getenv("VITALS_STREAM", "vitals")
    # write explicit ts plus blob (compatible with export logic)
    r.xadd(stream, {"ts": ts, "blob": json.dumps(envelope, separators=(",", ":"))})

    try:
        if critical:
            notify(f"[HM] critical {metric}={value}{unit} from {source}", priority="high")
    except Exception as e:
        print(f"[hm] notify failed: {e}")

    # Metrics are best-effort: report the failure instead of hiding it, but never fail the ingest.
    try:
        inc("alerts_sent", 1 if critical else 0)
        mark("last_msg_ts", ts)
    except Exception as e:
        print(f"[hm] metrics update failed: {e}")

    return {"ok": True, "ts": ts, "critical": critical}


# Module-level subscriber state, written by start()/stop(), _worker() and _on_message().
_running = False  # set True by start() and False by stop(); checked by _worker's outer loop
_thread: Optional[threading.Thread] = None  # background thread created by start()
_last_connect_rc: Optional[int] = None  # 0 after _connect_client() succeeds, -1 after a loop error
_last_msg_ts: Optional[int] = None  # "ts" of the most recent MQTT message seen by _on_message()


def _redis_set_cooldown(r: Redis, key: str, ttl_sec: int) -> bool:
    """Try to claim the cooldown *key* for *ttl_sec* seconds.

    Issues SET with NX and EX, so the key is only written when it does not exist yet.
    Returns True only when the SET call itself returned True.
    """
    outcome = r.set(name=key, value="1", ex=ttl_sec, nx=True)
    if outcome is True:
        return True
    return False


def _cooldown_key(prefix: str) -> str:
    """Return the Redis key used for the cooldown named *prefix* ("hm:cd:<prefix>")."""
    return "hm:cd:{}".format(prefix)


def _update_last_measure(r: Redis, ts: int) -> None:
    """Store *ts* under "hm:last_measure_ts" with a 24-hour expiry."""
    one_day_sec = 24 * 60 * 60
    r.set("hm:last_measure_ts", ts, ex=one_day_sec)


def _get_last_measure(r: Redis) -> Optional[int]:
    """Read "hm:last_measure_ts" and return it as an int, or None when missing/empty."""
    raw = r.get("hm:last_measure_ts")
    if not raw:
        return None
    return int(raw)


def _connect_client() -> mqtt.Client:
    """Create a paho MQTT client and connect it to the configured broker.

    Host and port come from MQTT_HOST (default "mosquitto") and MQTT_PORT
    (default 1883). The client id is "hm-sub-<unix seconds>" and the keepalive
    is 30 seconds. Connection errors propagate to the caller.
    """
    broker_host = os.getenv("MQTT_HOST", "mosquitto")
    broker_port = int(os.getenv("MQTT_PORT", "1883"))
    client = mqtt.Client(client_id="hm-sub-{}".format(int(time.time())))
    client.connect(broker_host, broker_port, keepalive=30)
    return client


def _store_vitals_to_stream(p: dict) -> None:
    """Mirror process_vitals() store logic for MQTT path to ensure Redis has entries.

    Builds the payload {metric,value,unit,ts,source} from *p* (source defaults to
    "mqtt", a missing/None value becomes 0.0), wraps it in an AES-GCM envelope when
    HM_AES_KEY decodes (hex or base64) to exactly 32 bytes, otherwise in a plaintext
    envelope, and XADDs {"ts", "blob"} to the stream named by VITALS_STREAM
    (default "vitals").

    Redis failures are printed and swallowed. If a valid key is configured but the
    cryptography package cannot be imported, the record is NOT stored (it is never
    downgraded to plaintext). Conversion errors for "value"/"ts" propagate to the caller.
    """
    import base64
    import secrets

    now = int(time.time())
    metric = str(p.get("metric", "")).lower()
    raw_value = p.get("value")
    value = float(raw_value) if raw_value is not None else 0.0
    unit = str(p.get("unit", ""))
    ts = int(p.get("ts") or now)
    source = str(p.get("source") or "mqtt")
    payload = {"metric": metric, "value": value, "unit": unit, "ts": ts, "source": source}
    plain = json.dumps(payload, separators=(",", ":"))

    key = b""
    key_raw = os.getenv("HM_AES_KEY", "").strip()
    if key_raw:
        # Hex has to be tried first: every hex digit is also a valid base64 character,
        # so a 64-char hex key base64-decodes without error to 48 bytes and would be
        # rejected as "wrong length" before the hex interpretation is ever attempted.
        if len(key_raw) == 64:
            try:
                key = bytes.fromhex(key_raw)
            except ValueError:
                key = b""
        if len(key) != 32:
            try:
                key = base64.b64decode(key_raw)
            except ValueError:  # binascii.Error is a ValueError subclass
                key = b""
        if len(key) != 32:
            print("[hm] HM_AES_KEY is not a 32-byte hex/base64 key; storing plaintext")

    if len(key) == 32:
        # Imported lazily so plaintext mode keeps working without the cryptography package.
        try:
            from cryptography.hazmat.primitives.ciphers.aead import AESGCM
        except Exception as e:
            print(f"[hm] store-to-stream imports failed: {e}")
            return
        iv = secrets.token_bytes(12)
        ct = AESGCM(key).encrypt(iv, plain.encode("utf-8"), None)
        envelope = {
            "enc": "aesgcm",
            "iv": base64.b64encode(iv).decode(),
            "ct": base64.b64encode(ct).decode(),
        }
    else:
        envelope = {"enc": "none", "json": plain}

    try:
        r = get_redis()
        stream = os.getenv("VITALS_STREAM", "vitals")
        r.xadd(stream, {"ts": ts, "blob": json.dumps(envelope, separators=(",", ":"))})
    except Exception as e:
        print(f"[hm] redis xadd failed: {e}")


def _on_message(cfg: HMConfig, r: Redis, _client: mqtt.Client, msg: mqtt.MQTTMessage) -> None:
    """Handle one MQTT vitals message: persist it, then raise threshold alerts.

    The payload must be a UTF-8 JSON object; anything else is dropped silently so a
    malformed message can never raise out of the MQTT callback. The record is first
    mirrored to the Redis stream (failures are printed, not raised). Then:

    - "hr": refreshes the last-measure timestamp; alerts when value >= cfg.HR_HIGH
    - "spo2": alerts when value <= cfg.SPO2_LOW
    - "temp"/"temperature": alerts when value >= cfg.TEMP_HIGH

    Each alert is gated by its own Redis cooldown key. A missing, zero or unparsable
    "ts" falls back to the current time. Messages whose "value" is not numeric are
    stored but produce no alert.
    """
    global _last_msg_ts
    try:
        p = json.loads(msg.payload.decode("utf-8"))
    except Exception:
        # Untrusted broker input: deliberately broad so no malformed payload can raise here.
        return
    if not isinstance(p, dict):
        # Valid JSON that is not an object (list, number, string, null) has no fields to read.
        return

    try:
        _store_vitals_to_stream(p)
    except Exception as e:
        print(f"[hm] store-to-stream error: {e}")

    try:
        ts = int(p.get("ts") or time.time())
    except (TypeError, ValueError, OverflowError):
        ts = int(time.time())
    metric = str(p.get("metric", "")).lower()
    _last_msg_ts = ts

    if metric not in ("hr", "spo2", "temp", "temperature"):
        return
    try:
        val = float(p.get("value"))
    except (TypeError, ValueError, OverflowError):
        return

    suffix = f" source={p.get('source')}" if p.get("source") else ""

    if metric == "hr":
        _update_last_measure(r, ts)
        if val >= cfg.HR_HIGH:
            _alert_once(
                r,
                "hr_high",
                cfg.HR_ALERT_COOLDOWN_MIN * 60,
                f"Health alert: hr={val:.1f} bpm{suffix}",
                "high",
                ("emergencies_sent",),
                "last_emergency_ts",
                ts,
            )
    elif metric == "spo2":
        if val <= cfg.SPO2_LOW:
            unit = str(p.get("unit", "%")) or "%"
            _alert_once(
                r,
                "spo2_low",
                cfg.SPO2_ALERT_COOLDOWN_MIN * 60,
                f"Health alert: spo2={val:.0f} {unit}{suffix}",
                "high",
                ("emergencies_sent", "spo2_events"),
                "last_emergency_ts",
                ts,
            )
    else:  # "temp" / "temperature"
        if val >= cfg.TEMP_HIGH:
            unit = str(p.get("unit", "C")) or "C"
            _alert_once(
                r,
                "temp_high",
                cfg.TEMP_ALERT_COOLDOWN_MIN * 60,
                f"Health alert: temp={val:.1f} °{unit}{suffix}",
                "normal",
                ("alerts_sent", "temp_events"),
                "last_alert_ts",
                ts,
            )


def _alert_once(
    r: Redis,
    cooldown_name: str,
    ttl_sec: int,
    text: str,
    priority: str,
    counters: tuple,
    mark_name: str,
    ts: int,
) -> None:
    """Send one cooldown-gated notification; bump *counters* and *mark_name* if it was delivered."""
    if not _redis_set_cooldown(r, _cooldown_key(cooldown_name), ttl_sec):
        return
    ok = False
    try:
        ok = notify(text, priority=priority)
    except Exception as e:
        print(f"[hm] notify failed: {e}")
    if ok:
        for counter in counters:
            inc(counter, 1)
        mark(mark_name, ts)


def _worker(cfg: HMConfig, r: Redis) -> None:
    """Background loop: connect to the broker and dispatch "vitals/ingest" messages.

    Runs until the module-level ``_running`` flag is cleared. Each pass creates a
    fresh client via _connect_client(), wires the callbacks and blocks in
    loop_forever(). Any exception marks ``_last_connect_rc`` as -1, is printed, and
    the connection is retried after a 2-second pause.

    The subscription is issued from on_connect rather than once after connect():
    loop_forever() reconnects on its own after a dropped connection, and a
    clean-session reconnect starts without subscriptions, so a one-off subscribe
    would leave the worker connected but receiving nothing.

    NOTE(review): ``_running`` is only re-checked between connection attempts, so
    stop() does not interrupt a loop_forever() call that is already blocking.
    """
    global _running, _last_connect_rc

    def _subscribe(client, *_args) -> None:
        # *_args absorbs (userdata, flags, rc[, properties]) for either paho callback API.
        client.subscribe("vitals/ingest", qos=0)

    while _running:
        c = None
        try:
            c = _connect_client()
            _last_connect_rc = 0
            c.on_connect = _subscribe
            c.on_message = lambda _c, _u, m: _on_message(cfg, r, _c, m)
            c.loop_forever()
        except Exception as e:
            _last_connect_rc = -1
            print(f"[hm] mqtt loop error: {e}")
            if c is not None:
                # Release the socket of the failed client before building a new one.
                try:
                    c.disconnect()
                except Exception as close_err:
                    print(f"[hm] mqtt disconnect failed: {close_err}")
            time.sleep(2.0)


def start(cfg: HMConfig, r: Redis) -> None:
    """Launch the subscriber worker in a daemon thread named "hm-subscriber".

    Does nothing when the module-level running flag is already set.
    """
    global _running, _thread
    if not _running:
        _running = True
        worker_thread = threading.Thread(
            name="hm-subscriber",
            target=_worker,
            args=(cfg, r),
            daemon=True,
        )
        _thread = worker_thread
        worker_thread.start()
        print("[subscriber] started")


def stop() -> None:
    """Clear the module-level running flag that the worker loop checks."""
    global _running
    _running = False


def status() -> dict:
    """Return a snapshot of the subscriber state (running flag, last connect rc, last message ts)."""
    snapshot = {}
    snapshot["running"] = _running
    snapshot["last_connect_rc"] = _last_connect_rc
    snapshot["last_msg_ts"] = _last_msg_ts
    return snapshot


def check_no_measure_and_remind(cfg: HMConfig, r: Redis) -> Optional[str]:
    """Send a reminder when no hr measurement was recorded within the configured window.

    Reads "hm:last_measure_ts"; if it is missing or at least
    cfg.NO_MEASURE_WINDOW_MIN minutes old, tries to claim the "hm:cd:no_measure"
    cooldown key (cfg.REMINDER_COOLDOWN_MIN minutes) and, when claimed, sends the
    reminder. Metrics are updated only if notify() reports success.

    Returns "reminded(no-measure)" when the cooldown was claimed, otherwise None.
    """
    now = int(time.time())
    raw_last = r.get("hm:last_measure_ts")
    if raw_last:
        age_sec = now - int(raw_last)
        if age_sec < cfg.NO_MEASURE_WINDOW_MIN * 60:
            # A recent enough measurement exists: nothing to do.
            return None

    cooldown_sec = cfg.REMINDER_COOLDOWN_MIN * 60
    if not r.set(name="hm:cd:no_measure", value="1", ex=cooldown_sec, nx=True):
        # Another reminder was already issued inside the cooldown window.
        return None

    delivered = False
    try:
        delivered = notify(f"Reminder: no hr measurement in last {cfg.NO_MEASURE_WINDOW_MIN} minutes", priority="normal")
    except Exception as e:
        print(f"[hm] notify failed: {e}")
    if delivered:
        inc("alerts_sent", 1)
        mark("last_alert_ts", now)
    return "reminded(no-measure)"
