from __future__ import annotations

import json
import time

from shm_shared.logging import setup_logging
from shm_shared.events import EventBus, parse_payload
from shm_shared.settings import settings

from .pipeline import clean_upload_influx

# Module-level logger for this worker, built by the shared logging helper.
logger = setup_logging("cleaning.worker")

# Inbound bus: reads the raw_ingested stream using the cleaning consumer group.
bus_in = EventBus(stream=settings.STREAM_RAW_INGESTED, group=settings.GROUP_CLEANING)
# Outbound bus: this module only publishes cleaned_ready events to it, so it is
# created without a consumer group (group=None).
bus_out = EventBus(stream=settings.STREAM_CLEANED_READY, group=None)

# Consumer name passed to claim_stale() / consume_one() on the inbound bus.
# NOTE(review): this is a fixed constant, so every process running this module
# presents the same consumer name — confirm that is intended if several
# replicas run side by side.
CONSUMER = "cleaning_worker"


def main():
    """Run the cleaning worker loop; never returns under normal operation.

    Steps:
      1. Ensure the consumer group exists on the inbound bus.
      2. Do one recovery pass over messages returned by ``claim_stale``
         (requested with a minimum idle time of 60000 ms).
      3. Loop forever, consuming one message at a time (blocking up to
         5000 ms per poll) and handing it to ``_handle_message``.

    A handler failure is logged with its traceback and the message is not
    ACKed here. In the main loop a failure is followed by a one-second pause.

    NOTE(review): the stale-claim pass runs only once, at startup. Messages
    that fail inside the main loop are not re-claimed by this process until it
    restarts — confirm something else retries them if that matters.
    """
    logger.info("Cleaning worker starting (Redis Streams)...")
    bus_in.ensure_group()

    # Startup recovery: retry whatever the bus reports as stale/pending.
    stale_entries = bus_in.claim_stale(CONSUMER, min_idle_ms=60000)
    for stale_id, stale_fields in stale_entries:
        try:
            _handle_message(stale_id, stale_fields)
        except Exception as err:
            # Not ACKed on purpose: the entry stays pending for a later attempt.
            logger.exception(f"Failed to process claimed message {stale_id}: {err}")

    # Steady state: poll, handle, repeat.
    while True:
        delivered = bus_in.consume_one(CONSUMER, block_ms=5000)
        if delivered:
            entry_id, entry_fields = delivered
            try:
                _handle_message(entry_id, entry_fields)
            except Exception as err:
                logger.exception(f"Failed to process message {entry_id}: {err}")
                # Not ACKed on purpose; short pause before polling again so a
                # persistent failure does not spin the loop.
                time.sleep(1)


def _handle_message(message_id: str, fields: dict[str, str]):
    """Process one inbound event: clean the upload, publish ``cleaned_ready``, ACK.

    Messages whose payload ``type`` is not ``"raw_ingested"`` are logged and
    ACKed without further work so they do not stay pending on this stream.

    ``sensor_names`` is normalised once and the same value is used for both
    the cleaning call and the published event:
      * a list is used as-is;
      * a string is parsed as JSON and used if it decodes to a list;
      * anything else (including unparsable strings) becomes ``None``.

    Args:
        message_id: Stream entry id, used to ACK on the inbound bus.
        fields: Raw entry fields as delivered by the bus; decoded with
            ``parse_payload``.

    Raises:
        KeyError: If the payload has no ``upload_id``.
        ValueError, TypeError: If ``upload_id`` or ``sampling_hz`` cannot be
            converted to a number.
        Exception: Anything raised by the cleaning pipeline or by publishing
            propagates. In all of these cases the message is NOT ACKed,
            because the ACK happens only after a successful publish.
    """
    payload = parse_payload(fields)
    etype = payload.get("type")
    if etype != "raw_ingested":
        # Wrong stream payload; ACK to avoid blocking this stream, but leave a
        # trace so misrouted events are visible.
        logger.warning(f"Ignoring message {message_id} with unexpected type {etype!r}")
        bus_in.ack(message_id)
        return

    upload_id = int(payload["upload_id"])
    logger.info(f"Received raw_ingested upload_id={upload_id}")

    raw_sensor_names = payload.get("sensor_names")
    sensor_names = raw_sensor_names
    if isinstance(sensor_names, str):
        # Tolerate a stringified list by decoding it as JSON.
        try:
            sensor_names = json.loads(sensor_names)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            sensor_names = None
    if not isinstance(sensor_names, list):
        if raw_sensor_names is not None:
            logger.warning(
                f"Ignoring unusable sensor_names for upload_id={upload_id}: {raw_sensor_names!r}"
            )
        # None is what the pipeline receives when no usable list was supplied.
        sensor_names = None

    # Run cleaning (Influx Raw -> Cleaned).
    res = clean_upload_influx(
        upload_id,
        structure_name_override=payload.get("structure_name"),
        sensor_names=sensor_names,
        # Missing or falsy sampling_hz falls back to 100.0.
        original_fs=float(payload.get("sampling_hz") or 100.0),
    )

    # Publish next stage, forwarding the event metadata unchanged.
    bus_out.publish(
        {
            "type": "cleaned_ready",
            "upload_id": upload_id,
            "structure_id": payload.get("structure_id"),
            "structure_name": payload.get("structure_name"),
            "start_time": payload.get("start_time"),
            "end_time": payload.get("end_time"),
            "cleaned_fs": res.get("cleaned_fs"),
            "sensor_names": sensor_names,
        },
    )

    # ACK only after the downstream event is out, so a failure above leaves
    # the message pending.
    bus_in.ack(message_id)
    logger.info(f"Cleaning complete upload_id={upload_id}; published cleaned_ready")


# Start the worker loop only when this module is executed as a script.
if __name__ == "__main__":
    main()
