from __future__ import annotations

import json
from typing import Dict, List, Set

from ..model import Attribute, DiagramElement, Entity, Input, Mechanic
from ..model.sheep_wolf_diagram import SHEEP_WOLF_DIAGRAM
from .preprocess import preprocess_diagram


def dfs_to_target(node: DiagramElement, path: List[DiagramElement], visited: Set[str], paths: List[List[DiagramElement]]) -> None:
    """Depth-first search from *node*, recording every path that reaches a target.

    A branch terminates when it reaches either:

    * an ``Entity`` whose ``is_target`` is truthy (the entity ends the path), or
    * an ``Attribute`` whose ``parent`` is truthy and has a truthy
      ``is_target``; the attribute and then its parent end the path.

    Otherwise the search follows each of ``node.links`` whose ``target`` is
    not ``None`` and whose id is not already in *visited*.

    Args:
        node: Element currently being expanded.
        path: Elements preceding *node* on the current branch (not mutated).
        visited: IDs of elements on the current branch, used to avoid cycles.
            Callers seed it with the start node's id. Every id added here is
            removed again after its sub-search, so on normal return the set
            holds exactly what it held on entry.
        paths: Output list; each completed path is appended as a new list.
    """
    # Case 1: we are standing on the target entity itself.
    if isinstance(node, Entity) and node.is_target:
        paths.append([*path, node])
        return

    # Case 2: an attribute that belongs to the target entity; the path is
    # closed by appending the owning entity after the attribute.
    owner = getattr(node, "parent", None) if isinstance(node, Attribute) else None
    if owner and getattr(owner, "is_target", False):
        paths.append([*path, node, owner])
        return

    # Otherwise expand outgoing links. `prefix` is never mutated by the
    # recursive calls (they only build new lists), so it can be shared.
    prefix = [*path, node]
    for edge in getattr(node, "links", []):
        successor = getattr(edge, "target", None)
        if successor is not None and successor.id not in visited:
            visited.add(successor.id)
            dfs_to_target(successor, prefix, visited, paths)
            visited.remove(successor.id)


def _collect_target_paths(diagram) -> List[List[DiagramElement]]:
    """Return every path found by ``dfs_to_target`` from each ``Input`` node of *diagram*."""
    paths: List[List[DiagramElement]] = []
    for node in diagram.nodes.values():
        if isinstance(node, Input):
            # Seed `visited` with the start node so the search never re-enters it.
            dfs_to_target(node, [], {node.id}, paths)
    return paths


def _count_visible_attributes(entity: DiagramElement) -> int:
    """Count links from *entity* to ``Attribute`` nodes whose ``visible`` is truthy (default True)."""
    return sum(
        1
        for link in getattr(entity, "links", [])
        if isinstance(getattr(link, "target", None), Attribute)
        and getattr(link.target, "visible", True)
    )


def _link_between(source: DiagramElement, dest: DiagramElement):
    """Return the link of *source* whose ``target`` is *dest* (by identity), or None.

    If several links point at *dest*, the last one in ``source.links`` order
    is returned; the path itself does not record which parallel link was taken.
    """
    found = None
    for link in getattr(source, "links", []):
        if getattr(link, "target", None) is dest:
            found = link
    return found


def calculate_mltt(diagram_source=None) -> Dict[str, float]:
    """Compute a per-path MLTT value for every input-to-target path of a diagram.

    The diagram is passed through ``preprocess_diagram`` and every path from an
    ``Input`` node to a target is enumerated with ``dfs_to_target``. For each
    path:

    * ``dof`` is the input node's ``degree_of_freedom`` (default 1.0);
    * ``num_mechanics`` is the number of ``Mechanic`` nodes strictly between
      the input and the target;
    * ``visible_attrs`` is the number of visible ``Attribute`` nodes linked
      from the target entity;
    * ``granularity`` is ``abs(link.granularity)`` (default 1.0) of the link
      the path follows *out of* its last mechanic, or 1.0 if there is no mechanic;
    * ``order`` is ``abs(target.magnitude_order)`` (default 1.0).

    The value is ``(dof * num_mechanics * visible_attrs) ** exponent``, where
    ``exponent = order / granularity``, or ``order`` when granularity is 0.

    Args:
        diagram_source: Diagram to analyse. ``None`` (the default) uses
            ``SHEEP_WOLF_DIAGRAM``, matching the original behaviour.

    Returns:
        A dict keyed ``"<input> -> <last mechanic> -> <target>"``, or
        ``"<input> -> <target>"`` when the path contains no mechanic, mapping to
        that path's value.
    """
    if diagram_source is None:
        diagram_source = SHEEP_WOLF_DIAGRAM
    diagram = preprocess_diagram(diagram_source)

    results: Dict[str, float] = {}
    for path in _collect_target_paths(diagram):
        input_node = path[0]
        target_entity = path[-1]

        dof = getattr(input_node, "degree_of_freedom", 1.0)

        # Interior positions only: the endpoints are the input and the target.
        mech_positions = [
            i for i in range(1, len(path) - 1) if isinstance(path[i], Mechanic)
        ]
        num_mechanics = len(mech_positions)

        last_mech = None
        granularity = 1.0
        if mech_positions:
            last_pos = mech_positions[-1]
            last_mech = path[last_pos]
            # Use the link the path actually traverses out of the last mechanic,
            # i.e. the one to its successor. Picking any link to *any* path node
            # could select a back-edge or a direct link to the target instead.
            link = _link_between(last_mech, path[last_pos + 1])
            if link is not None:
                granularity = abs(getattr(link, "granularity", 1.0))

        order = abs(getattr(target_entity, "magnitude_order", 1.0))

        base = dof * num_mechanics * _count_visible_attributes(target_entity)
        exponent = order / granularity if granularity != 0 else order
        # NOTE(review): a negative `dof` combined with a non-integer exponent
        # yields a complex number, which json.dumps cannot serialise. Confirm
        # degree_of_freedom is non-negative.
        mltt_value = base ** exponent

        if last_mech is not None:
            path_key = f"{input_node.id} -> {last_mech.id} -> {target_entity.id}"
        else:
            path_key = f"{input_node.id} -> {target_entity.id}"
        # NOTE(review): paths that differ only in their intermediate nodes share
        # a key, so a later path overwrites an earlier one. Confirm that is intended.
        results[path_key] = mltt_value

    return results


if __name__ == "__main__":
    # Script entry point: dump the per-path MLTT values as indented JSON.
    print(json.dumps(calculate_mltt(), indent=2))

