from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional

import json
import redis

from .settings import settings
from .logging import setup_logging

# Module-level logger for the event bus, obtained from the project's
# setup_logging helper (imported from .logging) under the name "shm.events".
logger = setup_logging("shm.events")


@dataclass
class EventBus:
    """Redis Streams event bus for a single stage stream + consumer group.

    IMPORTANT:
    - Do NOT share one consumer group across *different* logical consumers.
    - This project uses *per-stage streams* to avoid fanout/ACK bugs.

    Each instance is bound to exactly one Redis stream and (optionally) one consumer group.

    One ``redis.Redis`` client (and hence one connection pool) is created lazily on
    first use and reused by every method of the instance. Changing ``redis_url`` after
    the first Redis call therefore has no effect on that instance.
    """

    redis_url: str = settings.REDIS_URL
    stream: str = settings.STREAM_RAW_INGESTED
    group: Optional[str] = None

    def __post_init__(self) -> None:
        # Plain attribute, not a dataclass field: keeps the live client out of
        # __init__/__repr__/__eq__ and out of dataclasses.asdict()/replace().
        self._redis: Optional[redis.Redis] = None

    def _client(self) -> redis.Redis:
        """Return this instance's Redis client, creating it on first use.

        The client uses ``decode_responses=True``, so replies are ``str``, not ``bytes``.
        """
        # getattr tolerates instances created without __init__ (e.g. via __new__).
        client = getattr(self, "_redis", None)
        if client is None:
            # redis.from_url builds a brand-new connection pool each time it is called;
            # doing that per operation (as a consume loop would) churns connections.
            client = redis.from_url(self.redis_url, decode_responses=True)
            self._redis = client
        return client

    def ensure_group(self) -> None:
        """Create the consumer group (and the stream, if missing) when a group is set.

        The group starts at ID ``0-0``, i.e. it sees every entry already in the stream.
        An already-existing group (``BUSYGROUP`` reply) is not an error; any other
        ``ResponseError`` is re-raised. Does nothing when ``group`` is unset.
        """
        if not self.group:
            return
        r = self._client()
        try:
            r.xgroup_create(self.stream, self.group, id="0-0", mkstream=True)
            logger.info("Created consumer group '%s' on stream '%s'", self.group, self.stream)
        except redis.exceptions.ResponseError as e:
            if "BUSYGROUP" in str(e):
                return
            raise

    def publish(self, payload: Dict[str, Any]) -> str:
        """Publish a message to this stream.

        The full payload is stored as a JSON string under ``'payload'``; values that
        are not JSON-serializable are converted with ``str``. A few common keys are
        also copied to top-level string fields (when present and not ``None``) for
        easy filtering/queries.

        Returns:
            The stream entry ID assigned by Redis.
        """
        r = self._client()
        # Store both: JSON payload + flattened strings for convenient access
        fields: Dict[str, str] = {
            "payload": json.dumps(payload, default=str),
        }
        # Common keys promoted (best-effort)
        for k in ("type", "upload_id", "structure_id", "structure_name", "start_time", "end_time", "filename"):
            if k in payload and payload[k] is not None:
                fields[k] = str(payload[k])
        return r.xadd(self.stream, fields)

    def consume_one(
        self,
        consumer: str,
        block_ms: int = 5000,
        read_count: int = 1,
    ) -> Optional[tuple[str, Dict[str, str]]]:
        """Read the next never-delivered message for ``consumer`` from this group.

        Args:
            consumer: Consumer name within the group.
            block_ms: How long to block waiting for a message (XREADGROUP ``BLOCK``).
            read_count: Accepted for backward compatibility but ignored. At most one
                message is fetched because only one can be returned. Fetching more
                would deliver the extras to this consumer's pending list without
                handing them to the caller, stranding them until reclaimed.

        Returns:
            ``(message_id, fields)``, or ``None`` if nothing arrived in time.
            ``fields`` includes a JSON string in ``'payload'``.

        Raises:
            RuntimeError: If this bus has no consumer group.
        """
        if not self.group:
            raise RuntimeError("consume_one requires a consumer group")
        self.ensure_group()
        r = self._client()
        resp = r.xreadgroup(
            groupname=self.group,
            consumername=consumer,
            streams={self.stream: ">"},
            count=1,  # deliberately not read_count: see docstring
            block=block_ms,
        )
        if not resp:
            return None
        _stream_name, messages = resp[0]
        if not messages:
            return None
        msg_id, fields = messages[0]
        return msg_id, fields

    def ack(self, message_id: str) -> None:
        """Acknowledge ``message_id`` in this group; a no-op when no group is set."""
        if not self.group:
            return
        r = self._client()
        r.xack(self.stream, self.group, message_id)

    def claim_stale(self, consumer: str, min_idle_ms: int = 60000, count: int = 10) -> list[tuple[str, Dict[str, str]]]:
        """Claim pending messages idle for at least ``min_idle_ms`` to ``consumer``.

        Only the first ``count`` entries of the group's pending list (lowest IDs first)
        are inspected.

        Returns:
            The claimed ``(message_id, fields)`` pairs. Entries whose message no longer
            exists in the stream are skipped. Returns ``[]`` when no group is set.
        """
        if not self.group:
            return []
        self.ensure_group()
        r = self._client()
        pending = r.xpending_range(self.stream, self.group, min="-", max="+", count=count)
        stale_ids = [p["message_id"] for p in pending if p.get("time_since_delivered", 0) >= min_idle_ms]
        if not stale_ids:
            return []
        # Passing min_idle_time again makes XCLAIM skip entries another consumer
        # claimed between the XPENDING scan and now.
        claimed = r.xclaim(self.stream, self.group, consumername=consumer, min_idle_time=min_idle_ms, message_ids=stale_ids)
        # Before Redis 7, XCLAIM replies with nil for messages deleted/trimmed from the
        # stream, which redis-py surfaces as (None, None); never hand those to callers.
        return [(mid, fields) for mid, fields in claimed if mid is not None and fields is not None]


def parse_payload(fields: Dict[str, str]) -> Dict[str, Any]:
    """Parse the JSON payload stored in Redis stream fields.

    Args:
        fields: Field mapping of one stream entry, e.g. the ``fields`` returned by
            ``EventBus.consume_one`` or ``EventBus.claim_stale``.

    Returns:
        The decoded ``'payload'`` value when it is present, non-empty, valid JSON
        and a JSON object. Otherwise falls back to a shallow copy of ``fields``
        (the flattened top-level fields).
    """
    raw = fields.get("payload")
    if not raw:
        return dict(fields)
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError (and UnicodeDecodeError for bytes);
        # TypeError a non-str/bytes value; RecursionError pathologically deep nesting.
        return dict(fields)
    # EventBus.publish is annotated to take a dict, but other producers may write any
    # JSON; keep the Dict return contract instead of leaking a list/str/number.
    if not isinstance(parsed, dict):
        return dict(fields)
    return parsed
