"""
EA-RAG 多专家融合器 v3

对应架构图中的 Multi-Expert Fusion 模块，负责:
- Result Aggregation (结果聚合)
- Conflict Detection (冲突检测)  
- Source Attribution (来源标注)
"""

from typing import List, Dict, Tuple
from .config import EXPERTS, CONFLICT_PAIRS
from .models import RetrievalResult, ExpertResult, CodeRequirement, StructuredResponse


class MultiExpertFusion:
    """
    Multi-expert fusion stage of the EA-RAG pipeline.

    Corresponds to the "Multi-Expert Fusion" box in the architecture diagram:

    - Result Aggregation: merge the chunks retrieved by every activated expert.
    - Conflict Detection: flag predefined expert pairs whose requirements may
      conflict when both experts produced code-layer results.
    - Source Attribution: record which expert / document / clause each
      retrieved chunk came from.
    """

    def __init__(self):
        # Iterable of (expert_id_1, expert_id_2, description) triples,
        # consumed by _detect_conflicts.
        self.conflict_pairs = CONFLICT_PAIRS

    @staticmethod
    def _expert_label(expert_id: str) -> Tuple[str, str]:
        """
        Return the display ``(name, icon)`` pair for an expert.

        Falls back to the raw ``expert_id`` as the name and "📄" as the icon
        when the expert (or either field) is missing from ``EXPERTS``, so an
        unconfigured expert never raises ``KeyError``.
        """
        config = EXPERTS.get(expert_id, {})
        return config.get("name", expert_id), config.get("icon", "📄")

    def fuse(self, retrieval_result: RetrievalResult) -> Dict:
        """
        Fuse the per-expert retrieval results into a single structure.

        Args:
            retrieval_result: the multi-expert retrieval result.

        Returns:
            A dict with:
            - aggregated_code: code-layer chunks from all experts, by score desc.
            - aggregated_experience: experience-layer chunks, by score desc.
            - conflicts: potential cross-expert conflicts.
            - sources: per-chunk source attribution records.
            - activated_experts: ``{"id", "name", "icon"}`` for each activated
              expert that is present in ``EXPERTS`` (unknown ids are omitted).
        """
        # 1. Result Aggregation
        aggregated = self._aggregate_results(retrieval_result)

        # 2. Conflict Detection
        conflicts = self._detect_conflicts(retrieval_result)

        # 3. Source Attribution
        sources = self._attribute_sources(retrieval_result)

        activated_experts = []
        for eid in retrieval_result.activated_experts:
            # Only configured experts are surfaced in this summary list.
            if eid not in EXPERTS:
                continue
            name, icon = self._expert_label(eid)
            activated_experts.append({"id": eid, "name": name, "icon": icon})

        return {
            "aggregated_code": aggregated["code"],
            "aggregated_experience": aggregated["experience"],
            "conflicts": conflicts,
            "sources": sources,
            "activated_experts": activated_experts,
        }

    def _aggregate_results(self, retrieval_result: RetrievalResult) -> Dict:
        """
        Result Aggregation: collect every activated expert's chunks.

        Each chunk becomes a dict tagged with its expert's id, name and icon.
        Both layers are then sorted by score, descending, across all experts.
        The sort is stable, so ties keep expert-activation order. Items from
        different experts may therefore be interleaved; build_context_for_llm
        regroups them per expert for display.

        Returns:
            ``{"code": [...], "experience": [...]}``
        """
        aggregated_code = []
        aggregated_experience = []

        for expert_id in retrieval_result.activated_experts:
            result = retrieval_result.expert_results.get(expert_id)
            if result is None:
                continue
            name, icon = self._expert_label(expert_id)

            # Code layer: regulation clauses carry an optional clause number.
            for chunk, score in result.code_chunks:
                aggregated_code.append({
                    "expert_id": expert_id,
                    "expert_name": name,
                    "expert_icon": icon,
                    "content": chunk.content,
                    "source": chunk.metadata.get("source", "未知"),
                    "clause": chunk.metadata.get("clause", ""),
                    "score": score,
                })

            # Experience layer: no clause field.
            for chunk, score in result.experience_chunks:
                aggregated_experience.append({
                    "expert_id": expert_id,
                    "expert_name": name,
                    "expert_icon": icon,
                    "content": chunk.content,
                    "source": chunk.metadata.get("source", "未知"),
                    "score": score,
                })

        aggregated_code.sort(key=lambda x: x["score"], reverse=True)
        aggregated_experience.sort(key=lambda x: x["score"], reverse=True)

        return {
            "code": aggregated_code,
            "experience": aggregated_experience,
        }

    def _detect_conflicts(self, retrieval_result: RetrievalResult) -> List[Dict]:
        """
        Conflict Detection: flag potentially conflicting expert pairs.

        Example from the architecture diagram:
        - Fire requires evacuation corridors >= 1.4 m.
        - Accessibility requires accessible routes >= 1.8 m.
        -> Satisfying only the fire requirement may neglect accessibility.

        This is a coarse heuristic. A pair from ``self.conflict_pairs`` is
        reported when both experts are activated and both returned at least
        one code-layer chunk. Chunk contents are not compared.
        """
        conflicts = []
        activated = set(retrieval_result.activated_experts)

        for expert1, expert2, description in self.conflict_pairs:
            if expert1 not in activated or expert2 not in activated:
                continue

            result1 = retrieval_result.expert_results.get(expert1)
            result2 = retrieval_result.expert_results.get(expert2)
            if not (result1 and result2):
                continue
            if not (result1.code_chunks and result2.code_chunks):
                continue

            # Use the same fallback as the other methods so an expert id that
            # is missing from EXPERTS cannot raise KeyError here.
            name1, _ = self._expert_label(expert1)
            name2, _ = self._expert_label(expert2)
            conflicts.append({
                "type": "cross_domain",
                "experts": [expert1, expert2],
                "description": description,
                "recommendation": f"请注意协调 {name1} 和 {name2} 的要求",
            })

        return conflicts

    def _attribute_sources(self, retrieval_result: RetrievalResult) -> List[Dict]:
        """
        Source Attribution: one record per retrieved chunk.

        Records are emitted in expert-activation order; within an expert, code
        chunks come before experience chunks. Each record carries
        ``type == "code"`` or ``"experience"``. Only code records include
        ``clause``.
        """
        sources = []

        for expert_id in retrieval_result.activated_experts:
            result = retrieval_result.expert_results.get(expert_id)
            if result is None:
                continue
            name, icon = self._expert_label(expert_id)

            for chunk, score in result.code_chunks:
                sources.append({
                    "expert_id": expert_id,
                    "expert": name,
                    "icon": icon,
                    "type": "code",
                    "source": chunk.metadata.get("source", "未知"),
                    "clause": chunk.metadata.get("clause", ""),
                    "score": score,
                })

            for chunk, score in result.experience_chunks:
                sources.append({
                    "expert_id": expert_id,
                    "expert": name,
                    "icon": icon,
                    "type": "experience",
                    "source": chunk.metadata.get("source", "未知"),
                    "score": score,
                })

        return sources

    @staticmethod
    def _group_by_expert(items: List[Dict]) -> List[List[Dict]]:
        """
        Group items by ``expert_id``.

        Experts are ordered by their first appearance in ``items``. Items keep
        their original relative order inside each group.
        """
        groups: Dict[str, List[Dict]] = {}
        for item in items:
            groups.setdefault(item["expert_id"], []).append(item)
        return list(groups.values())

    def _append_expert_sections(
        self, lines: List[str], items: List[Dict], show_clause: bool
    ) -> None:
        """
        Append one ``### <icon> <name>`` section per expert to ``lines``.

        The input is sorted by score across experts, so one expert's items may
        be interleaved with another's. Grouping first ensures each expert gets
        exactly one header instead of a new header on every switch. Within a
        section, items keep their score order.
        """
        for group in self._group_by_expert(items):
            head = group[0]
            lines.append(f"### {head['expert_icon']} {head['expert_name']}")
            for item in group:
                clause = item.get("clause", "") if show_clause else ""
                clause_info = f" {clause}" if clause else ""
                lines.append(f"**来源: {item['source']}{clause_info}**")
                lines.append(item["content"])
                lines.append("")

    def build_context_for_llm(self, fusion_result: Dict) -> str:
        """
        Render a fusion result (as returned by ``fuse``) as LLM prompt context.

        Sections, in order:
        1. Activated experts, only when any are present.
        2. Code-layer clauses, grouped per expert.
        3. Experience-layer content, grouped per expert.
        4. Cross-domain conflict hints, only when any were detected.

        Returns:
            The context as a single newline-joined string.
        """
        lines: List[str] = []

        # Activated experts
        if fusion_result["activated_experts"]:
            expert_str = ", ".join(
                f"{e['icon']}{e['name']}"
                for e in fusion_result["activated_experts"]
            )
            lines.append(f"## 激活的专家: {expert_str}")
            lines.append("")

        # Code layer
        lines.append("## 各专家规范条文 (Code Layer)")
        lines.append("")
        self._append_expert_sections(
            lines, fusion_result["aggregated_code"], show_clause=True
        )
        if not fusion_result["aggregated_code"]:
            lines.append("（未检索到相关规范条文）")
            lines.append("")

        # Experience layer
        lines.append("## 各专家经验解读 (Exp. Layer)")
        lines.append("")
        self._append_expert_sections(
            lines, fusion_result["aggregated_experience"], show_clause=False
        )
        if not fusion_result["aggregated_experience"]:
            lines.append("（未检索到相关专家经验）")
            lines.append("")

        # Conflict hints
        if fusion_result["conflicts"]:
            lines.append("## ⚠️ 跨专业冲突提示")
            for conflict in fusion_result["conflicts"]:
                lines.append(f"- {conflict['description']}")
                lines.append(f"  建议: {conflict['recommendation']}")
            lines.append("")

        return "\n".join(lines)
