"""
EA-RAG 专家模块 - 多专家核心逻辑 v3

对应架构图中的:
- Fire Safety Expert (Code Layer + Exp. Layer)
- Accessibility Expert (Code Layer + Exp. Layer)
- Energy Expert (Code Layer + Exp. Layer)
- Other Experts (Code Layer + Exp. Layer)

每个专家包含:
- Code Layer: 规范原文 (如 GB 50016, GB 55037)
- Exp. Layer:
    - Clause Interpretation (条款解读)
    - Risk Tips (风险提示)
    - Case Experience (案例经验)
"""

import re
from pathlib import Path
from typing import List, Dict, Optional

from .config import (
    KNOWLEDGE_BASE_DIR, EXPERTS,
    CHUNK_SIZE, TOP_K_PER_EXPERT, KEYWORD_MATCH_THRESHOLD,
    CODE_MATCH_BOOST, MAX_ACTIVATED_EXPERTS
)
from .models import Chunk, ExpertResult


class Expert:
    """
    单个领域专家

    对应架构图中的一个 Expert 框，包含:
    - Code Layer: code/ 目录下的规范文档
    - Exp. Layer: experience/ 目录下的经验文档

    使用 ChromaDB 存储向量:
    - Collection: {expert_id}_code
    - Collection: {expert_id}_experience
    """

    def __init__(self, expert_id: str, config: Dict, embedding_service, vector_store):
        self.expert_id = expert_id
        self.name = config["name"]
        self.name_en = config.get("name_en", self.name)
        self.icon = config["icon"]
        self.keywords = config["keywords"]
        self.codes = config.get("codes", [])  # 相关规范代码
        self.description = config["description"]
        self.embedding_service = embedding_service
        self.vector_store = vector_store

        # Collection 名称
        self.code_collection = f"{expert_id}_code"
        self.experience_collection = f"{expert_id}_experience"

        # 目录
        self.base_dir = KNOWLEDGE_BASE_DIR / expert_id
        self.code_dir = self.base_dir / "code"
        self.experience_dir = self.base_dir / "experience"

    def setup_directories(self):
        """创建专家目录"""
        self.code_dir.mkdir(parents=True, exist_ok=True)
        self.experience_dir.mkdir(parents=True, exist_ok=True)

    def match_query(self, query: str) -> float:
        """
        计算查询与专家的匹配分数

        对应架构图中 Query Router 的 Identify Relevant Domains 逻辑

        Returns:
            匹配分数 (0.0 - 1.0+)
        """
        query_lower = query.lower()
        score = 0.0

        # 1. 关键词匹配
        keyword_matches = sum(1 for kw in self.keywords if kw in query_lower)
        score += keyword_matches * 0.1

        # 2. 规范代码匹配 (权重更高)
        # 支持多种格式: "GB 50016", "GB50016", "GB-50016"
        for code in self.codes:
            patterns = [
                code,  # GB 50016
                code.replace(" ", ""),  # GB50016
                code.replace(" ", "-"),  # GB-50016
            ]
            for pattern in patterns:
                if pattern.lower() in query_lower or pattern in query:
                    score += CODE_MATCH_BOOST * 0.1
                    break

        # 3. 专家名称匹配
        if self.name in query or self.name_en.lower() in query_lower:
            score += 0.2

        return score

    def index_exists(self) -> bool:
        """检查索引是否已存在"""
        code_exists = self.vector_store.collection_exists(self.code_collection)
        exp_exists = self.vector_store.collection_exists(self.experience_collection)
        return code_exists or exp_exists

    def build_index(self, force_rebuild: bool = False):
        """
        构建索引 - 处理 Code Layer 和 Exp. Layer

        Args:
            force_rebuild: 是否强制重建（删除现有数据）
        """
        print(f"   {self.icon} 构建 {self.name} 索引...")
        print(f"      Code Layer: {self.code_dir}")
        print(f"      Exp. Layer: {self.experience_dir}")

        if force_rebuild:
            # 删除现有 Collections
            self.vector_store.delete_collection(self.code_collection)
            self.vector_store.delete_collection(self.experience_collection)

        # 处理 Code Layer (规范文档)
        code_chunks = self._process_directory(self.code_dir, "code")
        if code_chunks:
            print(f"      生成 Code Layer {len(code_chunks)} 个分块的嵌入向量...")
            code_embeddings = self.embedding_service.get_embeddings_batch(
                [c.content for c in code_chunks]
            )
            self.vector_store.add_chunks(self.code_collection, code_chunks, code_embeddings)

        # 处理 Exp. Layer (经验文档)
        exp_chunks = self._process_directory(self.experience_dir, "experience")
        if exp_chunks:
            print(f"      生成 Exp Layer {len(exp_chunks)} 个分块的嵌入向量...")
            exp_embeddings = self.embedding_service.get_embeddings_batch(
                [c.content for c in exp_chunks]
            )
            self.vector_store.add_chunks(self.experience_collection, exp_chunks, exp_embeddings)

        print(f"      ✓ Code: {len(code_chunks)}, Exp: {len(exp_chunks)}")

    def _process_directory(self, directory: Path, doc_type: str) -> List[Chunk]:
        """处理目录中的文档"""
        chunks = []

        if not directory.exists():
            return chunks

        files = [f for f in directory.glob("*")
                 if f.is_file() and not f.name.startswith('.')]

        for file_path in files:
            content = self._read_file(file_path)
            if content:
                file_chunks = self._chunk_text(
                    content,
                    {"source": file_path.name, "path": str(file_path)},
                    doc_type
                )
                chunks.extend(file_chunks)

        return chunks

    def _read_file(self, file_path: Path) -> str:
        """读取文件"""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except:
            return ""

    def _chunk_text(self, text: str, metadata: Dict, doc_type: str) -> List[Chunk]:
        """
        智能分块 - 优先按章节分块
        """
        chunks = []

        # 尝试按章节分块 (适用于规范文档)
        if doc_type == "code":
            section_chunks = self._chunk_by_sections(text, metadata, doc_type)
            if section_chunks:
                return section_chunks

        # 按段落分块
        paragraphs = text.split('\n\n')
        current = ""

        for para in paragraphs:
            para = para.strip()
            if not para:
                continue

            if len(current) + len(para) < CHUNK_SIZE:
                current += para + "\n\n"
            else:
                if current:
                    chunks.append(Chunk(
                        content=current.strip(),
                        metadata=metadata.copy(),
                        doc_type=doc_type,
                        expert_id=self.expert_id
                    ))
                current = para + "\n\n"

        if current.strip():
            chunks.append(Chunk(
                content=current.strip(),
                metadata=metadata.copy(),
                doc_type=doc_type,
                expert_id=self.expert_id
            ))

        return chunks

    def _chunk_by_sections(self, text: str, metadata: Dict, doc_type: str) -> List[Chunk]:
        """
        按章节分块 - 识别规范条款结构

        识别模式:
        - ### 5.5.17 条款内容
        - 5.5.17 条款内容
        - 第5.5.17条 条款内容
        """
        chunks = []

        # 正则匹配章节
        section_pattern = re.compile(
            r'^(?:#{1,6}\s*)?(?:第)?(\d+(?:\.\d+)*)\s*(?:条)?\s*(.*)$',
            re.MULTILINE
        )

        matches = list(section_pattern.finditer(text))

        if len(matches) < 3:  # 如果章节太少，不使用章节分块
            return []

        for i, match in enumerate(matches):
            start = match.start()
            end = matches[i + 1].start() if i + 1 < len(matches) else len(text)

            section_num = match.group(1)
            section_content = text[start:end].strip()

            if len(section_content) > 30:
                chunk_metadata = metadata.copy()
                chunk_metadata["clause"] = f"§{section_num}"

                chunks.append(Chunk(
                    content=section_content,
                    metadata=chunk_metadata,
                    doc_type=doc_type,
                    expert_id=self.expert_id
                ))

        return chunks

    def search(self, query_embedding: List[float], activation_score: float = 0.0) -> ExpertResult:
        """
        检索 - 在 Code Layer 和 Exp. Layer 中分别检索（使用 ChromaDB）
        """
        # 从 ChromaDB 搜索
        code_results = self.vector_store.search(
            self.code_collection,
            query_embedding,
            TOP_K_PER_EXPERT
        )

        experience_results = self.vector_store.search(
            self.experience_collection,
            query_embedding,
            TOP_K_PER_EXPERT
        )

        return ExpertResult(
            expert_id=self.expert_id,
            expert_name=self.name,
            expert_icon=self.icon,
            code_chunks=code_results,
            experience_chunks=experience_results,
            activation_score=activation_score
        )

    def get_stats(self) -> Dict:
        """获取统计信息"""
        return {
            "expert_id": self.expert_id,
            "name": self.name,
            "name_en": self.name_en,
            "icon": self.icon,
            "codes": self.codes,
            "code_chunks": self.vector_store.get_collection_count(self.code_collection),
            "experience_chunks": self.vector_store.get_collection_count(self.experience_collection)
        }


class QueryRouter:
    """
    专家路由器 - 对应架构图中的 Query Router

    功能: Identify Relevant Domains
    决定哪些专家被激活 (Activated)，哪些不激活 (Not Activated)
    """

    def __init__(self, experts: Dict[str, Expert]):
        """
        Store the experts that queries can be routed to.

        Args:
            experts: Mapping from expert id to Expert. The mapping is kept by
                reference, not copied.
        """
        self.experts = experts

    def route(self, query: str) -> Dict:
        """
        路由查询到相关专家

        Returns:
            {
                "activated": [(expert_id, score), ...],    # 激活的专家
                "not_activated": [expert_id, ...]          # 未激活的专家
            }
        """
        scores = []

        # 计算每个专家的匹配分数
        for expert_id, expert in self.experts.items():
            score = expert.match_query(query)
            scores.append((expert_id, score))

        # 按分数排序
        scores.sort(key=lambda x: x[1], reverse=True)

        # 确定激活的专家
        activated = []
        for expert_id, score in scores:
            if score >= KEYWORD_MATCH_THRESHOLD * 0.1:
                activated.append((expert_id, score))
                if len(activated) >= MAX_ACTIVATED_EXPERTS:
                    break

        # 如果没有激活任何专家，激活 general
        if not activated:
            if "general" in self.experts:
                activated.append(("general", 0.0))
            else:
                # 激活分数最高的
                activated.append(scores[0])

        # 未激活的专家
        activated_ids = {e[0] for e in activated}
        not_activated = [eid for eid, _ in scores if eid not in activated_ids]

        return {
            "activated": activated,
            "not_activated": not_activated
        }

    def explain_routing(self, query: str) -> str:
        """解释路由决策 (用于调试)"""
        result = self.route(query)

        lines = [
            f"Query: {query}",
            "",
            "Activated:",
        ]

        for expert_id, score in result["activated"]:
            expert = self.experts[expert_id]
            lines.append(f"  ✅ {expert.icon} {expert.name} (score: {score:.2f})")

        lines.append("")
        lines.append("Not Activated:")

        for expert_id in result["not_activated"]:
            expert = self.experts[expert_id]
            lines.append(f"  ⬜ {expert.icon} {expert.name}")

        return "\n".join(lines)