import json
import os
import numpy as np
import cv2
import ast
from tqdm import tqdm
import pandas as pd
from scipy.stats import gaussian_kde
from scipy.spatial.distance import pdist
from scipy.spatial import cKDTree
import random
from PIL import Image, PngImagePlugin


def load_cell_data(cell_config):
    """Load shape descriptors and one target centroid layout for a cell type.

    An image is chosen from the cluster spreadsheet: among rows of the requested
    cluster, a random one whose ``n_cells`` lies within [0.8·n, 1.4·n] of the
    requested count, or, if none does, the row with the closest ``n_cells``.
    That image id is mapped to a patient name, whose centroids become the target.

    Args:
        cell_config: dict with keys ``fd_path`` (``.npy`` of Fourier descriptors),
            ``clusters_path`` (Excel workbook with a ``Setup<k>`` sheet),
            ``centroids_path`` and ``mapping_path`` (JSON files),
            ``selected_setup``, ``selected_cluster`` and ``selected_numcell``.

    Returns:
        ``(fd_array, target_centroids, selected_numcell, patient_name)`` where
        ``target_centroids`` is a float array built from each entry's ``"Centroid"``.

    Raises:
        ValueError: if the cluster has no rows, the selected image id has no
            patient mapping, or the patient has no centroids.
    """
    # NOTE(review): allow_pickle=True unpickles the file; only load trusted .npy files.
    fd_array = np.load(cell_config["fd_path"], allow_pickle=True)
    setup_name = f"Setup{cell_config['selected_setup']}"
    cluster_df = pd.read_excel(cell_config["clusters_path"], sheet_name=setup_name)
    with open(cell_config['centroids_path'], 'r', encoding='utf-8') as f:
        centroids = json.load(f)
    with open(cell_config['mapping_path'], 'r', encoding='utf-8') as f:
        id_to_patient = json.load(f)

    # Numeric cells may use a decimal comma (e.g. "3,5"); normalise to floats.
    numeric_columns = cluster_df.columns.drop(['img_id', 'Cluster'])
    for col in numeric_columns:
        cluster_df[col] = cluster_df[col].astype(str).str.replace(",", ".").astype(float)

    filtered = cluster_df[cluster_df["Cluster"] == cell_config["selected_cluster"]]
    if filtered.empty:
        raise ValueError(f"No cluster {cell_config['selected_cluster']} found.")

    n = cell_config['selected_numcell']
    # Inclusive on both ends, matching ">= lower and <= upper".
    in_range = filtered[filtered['n_cells'].between(n * 0.8, n * 1.4)]
    if in_range.empty:
        # argmin() is positional, hence .iloc.
        selected_row = filtered.iloc[(filtered['n_cells'] - n).abs().argmin()]
        print("Closest image selected")
    else:
        print(f"Images found: {len(in_range)}")
        selected_row = in_range.sample(n=1).iloc[0]

    # Ids may come back from Excel as floats ("12.0"); normalise to "12".
    img_id = str(int(float(selected_row['img_id']))).strip()
    if img_id not in id_to_patient:
        raise ValueError(f"No patient mapping found for image id {img_id}")
    patient_name = id_to_patient[img_id]

    if patient_name not in centroids:
        raise ValueError(f"No centroids found for {patient_name}")
    target_centroids = np.array([
        np.array(entry['Centroid'], dtype=float)
        for entry in centroids[patient_name]
    ])
    print(f"Target image: {patient_name}")

    return fd_array, target_centroids, n, patient_name

def filter(points, N, min_dist, max_trials=1000):
    """Randomly pick ``N`` rows of ``points`` that are pairwise ``>= min_dist`` apart.

    Greedy sampling: rows are visited in a random order. A row is kept when its
    Euclidean distance to every already-kept row is at least ``min_dist``. If a
    pass ends with fewer than ``N`` rows, a new random order is tried, up to
    ``max_trials`` passes.

    NOTE: the name shadows the builtin ``filter`` inside this module. It is kept
    for compatibility with existing callers.

    Args:
        points: array-like of shape (n, d) with candidate coordinates.
        N: number of points to return.
        min_dist: minimum allowed Euclidean distance between kept points.
        max_trials: number of random passes before giving up.

    Returns:
        ndarray of shape (N, d) holding the kept rows in acceptance order.
        Shape is (0, d) when ``N <= 0``.

    Raises:
        RuntimeError: if no pass yields ``N`` points. This is raised immediately
            when there are fewer than ``N`` candidates, since no pass could succeed.
    """
    points = np.asarray(points)
    if N <= 0:
        return points[:0]
    if len(points) < N:
        raise RuntimeError(
            f"Failed to sample {N} points with min_dist={min_dist}: "
            f"only {len(points)} candidates available."
        )

    for _ in range(max_trials):
        selected = np.empty((N,) + points.shape[1:], dtype=points.dtype)
        count = 0
        for idx in np.random.permutation(len(points)):
            candidate = points[idx]
            # A linear scan over at most N kept points is far cheaper than
            # rebuilding a KD-tree for every candidate.
            if count and np.min(np.linalg.norm(selected[:count] - candidate, axis=-1)) < min_dist:
                continue
            selected[count] = candidate
            count += 1
            if count == N:
                return selected
    raise RuntimeError(f"Failed to sample {N} points with min_dist={min_dist} after {max_trials} attempts.")

def get_kde_heatmap(points, x_grid, y_grid, bw_method=0.05):
    """Evaluate a Gaussian KDE fitted to ``points`` on a 2-D grid.

    Args:
        points: array-like of shape (n, 2) with sample coordinates.
        x_grid, y_grid: equally-shaped coordinate arrays, e.g. from ``np.mgrid``.
        bw_method: bandwidth passed to ``scipy.stats.gaussian_kde``. When the
            result is compared with another density, keep this equal to the
            bandwidth used for that density.

    Returns:
        Array shaped like ``x_grid`` with the KDE value at each grid node.
    """
    kde = gaussian_kde(np.asarray(points).T, bw_method=bw_method)
    x_grid = np.asarray(x_grid)
    y_grid = np.asarray(y_grid)
    positions = np.vstack([x_grid.ravel(), y_grid.ravel()])
    return kde(positions).reshape(x_grid.shape)

def reconstruct_contour(fd):
    """Turn Fourier descriptors back into pixel-rounded contour points.

    The inverse FFT of ``fd`` gives complex samples ``x + iy``. These are split
    into an ``(..., 2)`` array of ``(x, y)`` pairs, rounded to the nearest
    integer and returned as float32.
    """
    samples = np.fft.ifft(np.asarray(fd, dtype=np.complex128))
    xy = np.stack((samples.real, samples.imag), axis=-1)
    return xy.round().astype(np.float32)

def stratified_angles(n):
    """Return ``n`` angles in degrees, one drawn uniformly from each equal slice of [0, 360].

    Stratifying spreads the tried rotations around the circle instead of letting
    them cluster. Values come from the global ``np.random`` state, in slice order.
    """
    edges = np.linspace(0, 360, n + 1)
    lows, highs = edges[:-1], edges[1:]
    return list(map(np.random.uniform, lows, highs))

def rotate_contour(contour, angle_deg, center=None):
    """Rotate 2-D points by ``angle_deg`` degrees about ``center``.

    Applies the rotation matrix ``[[cos, -sin], [sin, cos]]`` to each row.
    ``center`` defaults to the mean of the points.

    Raises:
        ValueError: if the last axis of ``contour`` is not of length 2.
    """
    if contour.shape[-1] != 2:
        raise ValueError("Contour should be a 2D array")
    pivot = contour.mean(axis=0) if center is None else center

    theta = np.deg2rad(angle_deg)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # Points are row vectors multiplied on the right, so this is the transpose
    # of the usual column-vector rotation matrix.
    rotation_t = np.array([[cos_t, sin_t],
                           [-sin_t, cos_t]])
    return (contour - pivot) @ rotation_t + pivot

def placing_cell(contour, x_init, y_init, mask, canvas, color_map, cell_type, tissue_mask=None,
                 min_tissue_coverage=0.95, max_overlap=0.0):
    """Try to stamp one cell contour, centred at ``(x_init, y_init)``, onto ``mask`` and ``canvas``.

    The contour is translated so that its mean lands on ``(x_init, y_init)``.
    Vertices are converted with ``astype(np.int32)``, which truncates toward zero.
    The placement is rejected when:

    * ``tissue_mask`` is given and the filled contour is empty, or less than
      ``min_tissue_coverage`` of it lies on non-zero tissue pixels;
    * the filled contour covers no pixels;
    * more than ``max_overlap`` of the new cell's pixels are already non-zero in
      ``mask``. The default 0.0 rejects any overlap.

    On success the filled contour is drawn into ``mask`` and ``canvas`` in
    ``color_map[cell_type]``, and ``canvas`` gets a (0, 51, 102) outline. Drawing
    happens in place.

    Returns:
        ``(placed, mask, canvas)``. The arrays are the same objects passed in.
    """
    center = contour.mean(axis=0)
    offset = np.array([x_init, y_init]) - center
    translated_contour = (contour + offset).astype(np.int32).reshape((-1, 1, 2))

    if tissue_mask is not None:
        tissue_footprint = np.zeros_like(tissue_mask)
        cv2.drawContours(tissue_footprint, [translated_contour], -1, 255, -1)
        total_pixels = np.count_nonzero(tissue_footprint)
        if total_pixels == 0:
            return False, mask, canvas
        inside_pixels = np.count_nonzero(cv2.bitwise_and(tissue_footprint, tissue_mask))
        if inside_pixels / total_pixels < min_tissue_coverage:
            return False, mask, canvas

    # A single-channel footprint makes areas count pixels, not pixel-channels
    # of a multi-channel mask.
    footprint = np.zeros(mask.shape[:2], dtype=np.uint8)
    cv2.drawContours(footprint, [translated_contour], -1, 255, thickness=-1)
    inside = footprint > 0
    new_area = np.count_nonzero(inside)
    if new_area == 0:
        return False, mask, canvas

    occupied = mask.any(axis=-1) if mask.ndim == 3 else mask != 0
    overlap_area = np.count_nonzero(occupied & inside)
    if overlap_area / new_area > max_overlap:
        return False, mask, canvas

    cv2.drawContours(mask, [translated_contour], -1, color_map[cell_type], thickness=-1)
    cv2.fillPoly(canvas, [translated_contour], color_map[cell_type])
    cv2.polylines(canvas, [translated_contour], True, (0, 51, 102), thickness=1)

    return True, mask, canvas

def load_tissue_mask(tissue_json_path, desired_tissue_percentage, canvas_size=(1024,1024), tolerance=0.5):
    """Rasterise a tissue region whose recorded coverage is near the requested percentage.

    ``tissue_json_path`` holds a list of entries with ``"Image"``,
    ``"Percentage"`` and ``"Coordinates"`` (a list of polygons). One entry within
    ``±tolerance`` of ``desired_tissue_percentage`` is chosen at random. If no
    entry is that close, the closest entry is used. A missing ``"Percentage"``
    counts as 0.

    Returns:
        uint8 array of shape ``canvas_size``: 128 inside the tissue polygons, 0 elsewhere.

    Raises:
        ValueError: if the JSON contains no entries.
    """
    with open(tissue_json_path, "r", encoding="utf-8") as f:
        tissue_data = json.load(f)
    if not tissue_data:
        raise ValueError(f"No tissue entries found in {tissue_json_path}")

    def _distance(entry):
        return abs(entry.get("Percentage", 0) - desired_tissue_percentage)

    candidates = [entry for entry in tissue_data if _distance(entry) <= tolerance]
    if candidates:
        best_entry = random.choice(candidates)
        print(f"Among {len(candidates)}: {best_entry['Image']} ({best_entry.get('Percentage', 0)}%)")
    else:
        best_entry = min(tissue_data, key=_distance)
        print(f"No entries in ±{tolerance}%. Selected:{best_entry['Image']} ({best_entry.get('Percentage', 0)}%)")

    tissue_mask = np.zeros(canvas_size, dtype=np.uint8)
    # Polygons are filled one call at a time, so overlapping regions stay filled.
    for region in best_entry["Coordinates"]:
        contour_points = np.array(region, dtype=np.int32).reshape((-1, 1, 2))
        cv2.fillPoly(tissue_mask, [contour_points], 128)
    return tissue_mask

def _sample_layout_points(target_centroids, N, max_iterations=5, threshold=1e-15):
    """Sample ``N`` points whose KDE best matches the KDE of ``target_centroids``.

    Up to ``max_iterations`` candidate sets are drawn from the target KDE. Each
    set is thinned with ``filter`` to a spacing of ``min pairwise target
    distance // 3`` and scored by the mean squared difference between the two
    densities on a 512x512 grid.

    Returns:
        ``(best_points, best_loss)``. ``best_points`` is ``None`` when every
        attempt failed to thin down to ``N`` points.
    """
    if N <= 0:
        return np.empty((0, 2)), 0.0

    data = target_centroids.T
    kde = gaussian_kde(data, bw_method=0.05)
    # Evaluate both densities on a grid padded by 15 units around the targets.
    x_min, x_max = data[0].min() - 15, data[0].max() + 15
    y_min, y_max = data[1].min() - 15, data[1].max() + 15
    x_grid, y_grid = np.mgrid[x_min:x_max:512j, y_min:y_max:512j]
    positions = np.vstack([x_grid.ravel(), y_grid.ravel()])
    target_density = kde(positions).reshape(x_grid.shape)
    min_spacing = pdist(target_centroids).min() // 3

    best_loss = np.inf
    best_points = None
    for _ in range(max_iterations):
        # Oversample heavily so the spacing filter has enough candidates.
        sampled_points = kde.resample(N * 1000).T
        np.random.shuffle(sampled_points)
        try:
            sampled_points = filter(sampled_points, N, min_spacing)
        except RuntimeError:
            continue

        gen_density = get_kde_heatmap(sampled_points, x_grid, y_grid)
        current_loss = np.mean((gen_density - target_density) ** 2)
        if current_loss < best_loss:
            best_loss = current_loss
            best_points = sampled_points.copy()
        if best_loss < threshold:
            break
    return best_points, best_loss


def _points_inside_tissue(points, tissue_mask):
    """Keep the (x, y) points whose rounded pixel is on the mask and equals 128 (tissue)."""
    h, w = tissue_mask.shape
    cols = np.round(points[:, 0]).astype(int)
    rows = np.round(points[:, 1]).astype(int)
    in_bounds = (cols >= 0) & (cols < w) & (rows >= 0) & (rows < h)
    on_tissue = np.zeros(len(points), dtype=bool)
    on_tissue[in_bounds] = tissue_mask[rows[in_bounds], cols[in_bounds]] == 128
    return points[on_tissue]


def _place_with_fallbacks(fd_original, x_init, y_init, mask, canvas, class_to_color,
                          cell_type, tissue_mask, min_area, base_step,
                          angles_attempts=4, scales=(0.8, 0.6, 0.4, 0.2),
                          step_factors=(1, 2, 3, 4, 5, 6)):
    """Try to place one cell at (x_init, y_init), escalating through fallbacks.

    The attempts, in order:

    1. The full-size contour at ``angles_attempts`` stratified random rotations.
    2. For each factor in ``scales``, the shrunken contour unrotated, then at
       fresh stratified rotations. Shrinking stops once the contour area
       (``cv2.contourArea``) drops below ``min_area``.
    3. The full-size, unrotated contour nudged by multiples of ``base_step`` in
       the 8 compass directions.

    Returns:
        ``(placed, mask, canvas)``.
    """
    contour = reconstruct_contour(fd_original)

    def attempt(candidate, x, y):
        nonlocal mask, canvas
        success, mask, canvas = placing_cell(
            candidate, x, y, mask, canvas, class_to_color, cell_type, tissue_mask
        )
        return success

    for angle in stratified_angles(angles_attempts):
        if attempt(rotate_contour(contour, angle), x_init, y_init):
            return True, mask, canvas

    for scale in scales:
        contour_scaled = reconstruct_contour(fd_original * scale)
        if cv2.contourArea(contour_scaled.astype(np.float32)) < min_area:
            break
        if attempt(contour_scaled, x_init, y_init):
            return True, mask, canvas
        for angle in stratified_angles(angles_attempts):
            if attempt(rotate_contour(contour_scaled, angle), x_init, y_init):
                return True, mask, canvas

    for factor in step_factors:
        step = base_step * factor
        offsets = (
            (step, 0), (-step, 0), (0, step), (0, -step),
            (step, step), (-step, step), (step, -step), (-step, -step),
        )
        for dx, dy in offsets:
            if attempt(contour, x_init + dx, y_init + dy):
                return True, mask, canvas

    return False, mask, canvas


def generate_cell_layout(cell_types, selected_cell_types, desired_tissue_percentage, output_path, tissue_json_path="tissue_contours.json", canvas_size=(1024,1024)):
    """Synthesise one cell-layout image and save it as a PNG with generation metadata.

    Cell types in ``selected_cell_types`` are processed in order. Each type goes
    through these steps:

    * load descriptors and a target centroid layout (``load_cell_data``);
    * sample N points whose KDE approximates the target density;
    * centre the points on the canvas and drop any outside the tissue mask;
    * place one Fourier-descriptor contour per point, with rotation, shrinking
      and nudging fallbacks.

    Each placement is checked against the shared ``mask`` of cells already
    drawn, including cells of earlier types.

    Args:
        cell_types: mapping of cell-type name to the config dict read by ``load_cell_data``.
        selected_cell_types: names to draw, in drawing order.
        desired_tissue_percentage: 100 paints the whole canvas as tissue.
            Any other value selects a tissue mask from ``tissue_json_path``.
        output_path: destination PNG path.
        tissue_json_path: JSON file of tissue contours.
        canvas_size: ``(rows, cols)`` of the canvas.

    Returns:
        ``(canvas, output_path)``. ``canvas`` is the BGR uint8 image as drawn.
    """
    class_to_color = {
        "Fat": (255, 255, 255),
        "Other_Nuclei": (255, 0, 0),
        "Hepatocyte_Nuclei": (0, 0, 255),
        "Steatosis": (0, 255, 0),
    }

    # Shrinking stops once a scaled contour's area falls below this value.
    limit_areas = {
        "Steatosis": 3.0,
        "Hepatocyte_Nuclei": 83.0,
        "Other_Nuclei": 7.5,
        "Fat": 1.0
    }

    canvas = np.zeros((*canvas_size, 3), dtype=np.uint8)
    if desired_tissue_percentage == 100:
        tissue_mask = None
        print("No tissue mask applied. (100% tissue coverage)")
        canvas[:] = (128,128,128)
    else:
        # Pass canvas_size so the mask has the same shape as the canvas it indexes.
        tissue_mask = load_tissue_mask(tissue_json_path, desired_tissue_percentage, canvas_size=canvas_size)
        canvas[tissue_mask == 128] = (128,128,128)
    mask = np.zeros_like(canvas)

    # Points are (x, y) = (col, row), while canvas_size is (rows, cols).
    canvas_center = np.array(canvas_size[::-1], dtype=float) / 2
    h, w = canvas.shape[:2]
    # Use at least 1 px so nudging still moves the cell on small canvases.
    base_step = max(1, int(min(h, w) * 0.005))

    classes_meta = []

    for cell_type in selected_cell_types:
        print(f" === Processing {cell_type} ===")
        fd_array, target_centroids, N, patient_id = load_cell_data(cell_types[cell_type])
        print(f"Number of cells: {len(target_centroids)}")

        best_points, best_loss = _sample_layout_points(target_centroids, N)
        if best_points is None:
            print(f"Could not sample {N} points for {cell_type}; skipping.")
            classes_meta.append({
                "cell_type": str(cell_type),
                "mask_used": str(patient_id),
                "cell_requested": int(N),
                "cell_inserted": 0
            })
            continue
        print(f"Best Loss: {best_loss * 1e10:.8f}")

        # Centring every layer on the canvas also aligns all layers with each other.
        if len(best_points):
            best_points = best_points + (canvas_center - best_points.mean(axis=0))

        if tissue_mask is not None:
            best_points = _points_inside_tissue(best_points, tissue_mask)
        print(f"Valid points inside tissue: {len(best_points)} / {N}")

        plotted_cell = 0
        used_idx = set()
        all_idx = set(range(len(fd_array)))
        for i, (x_init, y_init) in enumerate(best_points):
            available_idx = list(all_idx - used_idx)
            if not available_idx:
                break
            idx = random.choice(available_idx)
            # Each descriptor is tried at most once per cell type, placed or not.
            used_idx.add(idx)
            found, mask, canvas = _place_with_fallbacks(
                fd_array[idx], x_init, y_init, mask, canvas, class_to_color,
                cell_type, tissue_mask, limit_areas[cell_type], base_step,
            )
            if found:
                plotted_cell += 1
            else:
                print(f"Cell {i} not inserted")

        print(f" {cell_type}: {plotted_cell} cells inserted.")

        classes_meta.append({
            "cell_type": str(cell_type),
            "mask_used": str(patient_id),
            "cell_requested": int(N),
            "cell_inserted": int(plotted_cell)
        })

    img = Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
    metadata = {
        "tissue_percentage": float(desired_tissue_percentage),
        "classes": classes_meta
    }
    png_meta = PngImagePlugin.PngInfo()
    png_meta.add_text("generation_metadata", json.dumps(metadata))
    img.save(output_path, pnginfo=png_meta)
    print(f"Layout saved in {output_path}")
    return canvas, output_path

def generate_multiple_layouts(cell_types, selected_cell_types, desired_tissue_percentage,
                              output_dir='Testing', tissue_json_path="tissue_contours.json"):
    """Generate one layout per entry of ``desired_tissue_percentage``.

    Layout ``i`` uses element ``i`` of each cell type's ``selected_cluster``,
    ``selected_numcell`` and ``selected_setup`` arrays. The four path entries are
    shared across layouts. Each layout is written to the first non-existing
    ``Test<k>.png`` (k = 1, 2, ...) in ``output_dir``, which is created if needed.
    """
    os.makedirs(output_dir, exist_ok=True)
    total = len(desired_tissue_percentage)
    shared_keys = ("centroids_path", "clusters_path", "fd_path", "mapping_path")
    per_layout_keys = ("selected_cluster", "selected_numcell", "selected_setup")

    for layout_no in tqdm(range(total), desc="Generating layouts"):
        print(f"==== GENERATING LAYOUT {layout_no+1}/{total} ====")
        layout_cell_types = {}
        for name, cfg in cell_types.items():
            layout_cfg = {key: cfg[key] for key in shared_keys}
            layout_cfg.update({key: cfg[key][layout_no] for key in per_layout_keys})
            layout_cell_types[name] = layout_cfg

        counter = 1
        output_path = os.path.join(output_dir, f"Test{counter}.png")
        while os.path.exists(output_path):
            counter += 1
            output_path = os.path.join(output_dir, f"Test{counter}.png")

        _canvas, _path = generate_cell_layout(
            layout_cell_types,
            selected_cell_types,
            desired_tissue_percentage=desired_tissue_percentage[layout_no],
            output_path=output_path,
            tissue_json_path=tissue_json_path,
        )

        print(f" Test {counter} saved in: {output_path}")

# NOTE(review): everything below runs at import time and generates and writes
# N_LAYOUTS images. Consider moving it under an ``if __name__ == "__main__":``
# guard if this module is ever imported just for its functions.

# Number of layouts to generate. generate_multiple_layouts reads element i of
# every per-layout array below, so they must all have at least this many entries.
N_LAYOUTS = 10

folder = os.getcwd()
cell_types = {
    "Steatosis": {
        "centroids_path": os.path.join(folder, "All_centroids/Steatosis_centroids.json"),
        "clusters_path": os.path.join(folder, "Cluster/Steatosis_clusters.xlsx"),
        "fd_path": os.path.join(folder, "Output_fd/fourier_descriptors_steatosis.npy"),
        "mapping_path": os.path.join(folder, "All_centroids/Steatosis_key_mapping.json"),
        "selected_cluster": np.full(N_LAYOUTS, 0),
        # Requested cell count per layout, drawn at random from [50, 130).
        "selected_numcell": np.random.randint(50, 130, size=N_LAYOUTS),
        "selected_setup": np.full(N_LAYOUTS, 1)
    },
    "Hepatocyte_Nuclei": {
        "centroids_path": os.path.join(folder, "All_centroids/Hepatocyte_Nuclei_centroids.json"),
        "clusters_path": os.path.join(folder, "Cluster/Hepatocyte_Nuclei_clusters.xlsx"),
        "fd_path": os.path.join(folder, "Output_fd/fourier_descriptors_hepatocytenuclei.npy"),
        "mapping_path": os.path.join(folder, "All_centroids/Hepatocyte_Nuclei_key_mapping.json"),
        "selected_cluster": np.full(N_LAYOUTS, 2),
        "selected_numcell": np.full(N_LAYOUTS, 120),
        "selected_setup": np.full(N_LAYOUTS, 1)
    },
    "Other_Nuclei": {
        "centroids_path": os.path.join(folder, "All_centroids/Other_Nuclei_centroids.json"),
        "clusters_path": os.path.join(folder, "Cluster/Other_Nuclei_clusters.xlsx"),
        "fd_path": os.path.join(folder, "Output_fd/fourier_descriptors_othernuclei.npy"),
        "mapping_path": os.path.join(folder, "All_centroids/Other_Nuclei_key_mapping.json"),
        "selected_cluster": np.full(N_LAYOUTS, 2),
        "selected_numcell": np.full(N_LAYOUTS, 200),
        "selected_setup": np.full(N_LAYOUTS, 1)
    },
    "Fat": {
        "centroids_path": os.path.join(folder, "All_centroids/Fat_centroids.json"),
        "clusters_path": os.path.join(folder, "Cluster/Fat_clusters.xlsx"),
        "fd_path": os.path.join(folder, "Output_fd/fourier_descriptors_fat.npy"),
        "mapping_path": os.path.join(folder, "All_centroids/Fat_key_mapping.json"),
        "selected_cluster": np.full(N_LAYOUTS, 1),
        "selected_numcell": np.full(N_LAYOUTS, 30),
        "selected_setup": np.full(N_LAYOUTS, 1)
    },
}
# Drawing order: each placement is checked against cells already drawn.
selected_cell_types = ["Steatosis", "Other_Nuclei", "Hepatocyte_Nuclei", "Fat"]

# A tissue percentage of 100 means "no tissue mask": the whole canvas is tissue.
generate_multiple_layouts(
    cell_types,
    selected_cell_types,
    desired_tissue_percentage=np.full(N_LAYOUTS, 100),
    output_dir="Testing",
    tissue_json_path=os.path.join(folder, "tissue_contours.json")
)