import json
import ijson
import os
import gc
import tkinter as tk
from tkinter import filedialog

from sklearn.cluster import AgglomerativeClustering
from tqdm import tqdm
import pandas as pd
import re
import umap
from collections import defaultdict
from scipy.fft import fft
import numpy as np
from scipy.interpolate import interp1d
from shapely.geometry import Polygon, box
from scipy.spatial import KDTree, Voronoi
from scipy.stats import gaussian_kde, entropy
import hdbscan
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

def stream_json_contours(file_path, target_class):
    """Lazily yield the records stored under one top-level key of a JSON file.

    The file is parsed incrementally with ``ijson.items`` using the prefix
    ``"<target_class>.item"``, i.e. the elements of the array found under
    ``target_class``, so the whole document is never held in memory.

    Args:
        file_path: Path of the JSON file to read.
        target_class: Key whose array elements are yielded.

    Yields:
        Each array element as parsed by ijson.
    """
    # Binary mode lets ijson decode the UTF-8 bytes itself; a text-mode handle
    # would decode with the platform's locale encoding first.
    with open(file_path, 'rb') as f:
        yield from ijson.items(f, f'{target_class}.item')


## Fourier-descriptor extraction: input file and output locations
folder = os.getcwd()
file_path = os.path.join(folder, "contours.json")
# file_path is a joined, non-empty path, so this always resolves to file_path.
data_path = file_path or None

print(" --- Extracting Fourier Descriptors ---")
print("Selected file: ", data_path)
output_path = "Output_fd"
# All descriptor files live in <cwd>/Output_fd, created on demand.
_fd_dir = os.path.join(folder, output_path)
os.makedirs(_fd_dir, exist_ok=True)
output_fd_path_tss = os.path.join(_fd_dir, 'fourier_descriptors_tissue.npy')
output_fd_path_fat = os.path.join(_fd_dir, 'fourier_descriptors_fat.npy')
output_fd_path_hptncl = os.path.join(_fd_dir, 'fourier_descriptors_hepatocytenuclei.npy')
output_fd_path_othncl = os.path.join(_fd_dir, 'fourier_descriptors_othernuclei.npy')
output_fd_path_ste = os.path.join(_fd_dir, 'fourier_descriptors_steatosis.npy')

def remove_duplicate_points(coords):  ## drop points that duplicate, or sit too close to, their predecessor
    """Return ``coords`` without points closer than 1e-6 to the previous point.

    The first point is always kept. Each remaining point is compared only with
    the point immediately before it in the input (the last point is not
    compared with the first).
    """
    step = coords[1:] - coords[:-1]
    step_len = np.hypot(step[:, 0], step[:, 1])
    # Prepend True for the first point, which has no predecessor to compare to.
    mask = np.insert(step_len > 1e-6, 0, True)
    return coords[mask]

def resample_contour(coords, num_points):  ## resample the contour to a fixed number of points evenly spaced along its length
    """Resample a polyline to ``num_points`` points equally spaced in arc length.

    Near-duplicate consecutive points are removed first. The samples run from
    the first to the last remaining point (both included) and are obtained by
    linear interpolation of x and y against cumulative path length. The path is
    taken as given: no segment is added between the last and the first point.

    Raises:
        ValueError: if fewer than two points remain after de-duplication.
    """
    pts = remove_duplicate_points(coords)
    if len(pts) < 2:
        raise ValueError("Contour too small for resampling")

    seg = np.diff(pts, axis=0)
    seg_len = np.hypot(seg[:, 0], seg[:, 1])
    # Arc-length position of every vertex, starting at 0.
    arc = np.concatenate([[0], np.cumsum(seg_len)])
    targets = np.linspace(0, arc[-1], num_points)

    resampled_axes = [
        interp1d(arc, pts[:, axis], kind='linear')(targets)
        for axis in (0, 1)
    ]
    return np.stack(resampled_axes, axis=-1)

def normalize_contour_fd(coords):  ## contour points become x + j*y so they can be Fourier transformed
    """Encode an (N, 2) contour as complex numbers centred on their mean.

    Only the translation is removed (the mean is subtracted); no scaling or
    rotation is applied here.
    """
    z = coords[:, 0] + 1j * coords[:, 1]
    return z - z.mean()

def is_contour_inside_image(coords, image_size=1024, margin=0):
    """Tell whether a contour stays strictly away from the image border.

    Args:
        coords: (N, 2) array of x, y positions.
        image_size: Side length of the square image, in pixels.
        margin: Extra border width that also counts as "outside".

    Returns:
        False when any coordinate is <= ``margin`` or >=
        ``image_size - 1 - margin``; True otherwise.
    """
    (x_lo, y_lo), (x_hi, y_hi) = coords.min(axis=0), coords.max(axis=0)
    upper_limit = image_size - 1 - margin

    touches_low_edge = x_lo <= margin or y_lo <= margin
    touches_high_edge = x_hi >= upper_limit or y_hi >= upper_limit
    return not (touches_low_edge or touches_high_edge)

def calculate_fd(contours_iter, output_path, class_name, min_len, n_resample, batch_size=500,
                 apply_inside_filter=True, image_size=1024, margin=0):
    """Compute Fourier descriptors for a stream of contours and save them to one ``.npy`` file.

    For every contour the coordinates are reshaped to (N, 2) float32, optionally
    filtered with ``is_contour_inside_image``, resampled to ``n_resample``
    points, centred, and transformed with an FFT. Descriptors are the raw
    complex64 FFT coefficients. They are written to numbered part files every
    ``batch_size`` descriptors and merged into ``output_path`` at the end.

    Args:
        contours_iter: Iterable of mappings that provide a ``"Coordinates"`` entry.
        output_path: Destination ``.npy`` file; part files are named
            ``f"{output_path}_part{k}.npy"``.
        class_name: Label used only in progress messages.
        min_len: Contours whose ``"Coordinates"`` has fewer entries are skipped.
        n_resample: Number of points (and FFT coefficients) per contour.
        batch_size: Number of descriptors buffered before a part file is written.
        apply_inside_filter: When true, drop contours touching the image border.
        image_size: Image side length passed to the border filter.
        margin: Border margin passed to the border filter.

    The output always exists afterwards: if no contour produced a descriptor an
    empty ``(0, n_resample)`` complex64 array is saved.
    """
    temp_files = []
    fd_batch = []
    count = 0
    skipped = 0
    print(f"Calculating Fourier Descriptors [{class_name}]...")

    def flush_batch():
        """Write the buffered descriptors to the next part file and empty the buffer."""
        nonlocal count
        temp_name = f"{output_path}_part{len(temp_files)}.npy"
        np.save(temp_name, np.stack(fd_batch))
        temp_files.append(temp_name)
        count += len(fd_batch)
        fd_batch.clear()

    try:
        for cell in contours_iter:
            coords = cell["Coordinates"]
            if len(coords) < min_len:
                continue
            try:
                coords_np = np.array(coords, dtype=np.float32).reshape(-1, 2)
                if apply_inside_filter and not is_contour_inside_image(
                        coords_np, image_size=image_size, margin=margin):
                    continue

                coords_resampled = resample_contour(coords_np, n_resample)
                complex_contour = normalize_contour_fd(coords_resampled)
                fd_batch.append(fft(complex_contour).astype(np.complex64))
            except Exception:
                # Best effort: a malformed or degenerate contour must not stop
                # the whole run, but it is counted so the loss is visible.
                skipped += 1
                continue

            if len(fd_batch) >= batch_size:
                flush_batch()
                gc.collect()
                print(f"Saved batch {len(temp_files)} ({count} total)")

        if fd_batch:
            flush_batch()

        if temp_files:
            all_fd = np.concatenate([np.load(f) for f in temp_files], axis=0)
        else:
            # np.concatenate rejects an empty list; keep the output well-formed.
            all_fd = np.empty((0, n_resample), dtype=np.complex64)
        np.save(output_path, all_fd)
    finally:
        # Part files are only intermediates: remove them on success and on error.
        for f in temp_files:
            if os.path.exists(f):
                os.remove(f)

    if skipped:
        print(f"Skipped {skipped} contours that could not be processed [{class_name}]")
    print(f"Saved total {len(all_fd)} descriptors to {output_path}")


# Per-class settings for the Fourier-descriptor extraction:
#   key in the input JSON -> (output .npy path, min_len, n_resample)
# min_len is 0 for every class, so the length filter in calculate_fd never
# drops a contour here; n_resample is the number of points per contour.
classes= {
    "Hepatocyte_Nuclei": (output_fd_path_hptncl, 0, 49),
    "Other_Nuclei": (output_fd_path_othncl, 0, 27),
    "Fat": (output_fd_path_fat, 0, 31),
    "Steatosis": (output_fd_path_ste, 0, 46),
}

# Each class is streamed from the input file separately and saved to its own file.
for class_name, (output, min_len, n_resample) in classes.items():
    contours_iter = stream_json_contours(data_path, class_name)
    calculate_fd(contours_iter, output, class_name, min_len, n_resample, apply_inside_filter=True)

print("All Fourier Descriptors calculated successfully!!")
## Analysis of spatial disposition

## Extracting centroids of all cells
# NOTE: output_path is rebound here; from this point on it is the
# <cwd>/All_centroids directory, no longer the "Output_fd" folder name.
print("--- Extracting Centroids ---")
output_path = os.path.join(folder, "All_centroids")
os.makedirs(output_path, exist_ok=True)

def _append_to_json(out_path, data):
    """Merge ``data`` into the JSON object stored at ``out_path`` and rewrite the file.

    Args:
        out_path: JSON file holding ``{patient: [entries, ...]}``; it is created
            when missing or empty.
        data: Mapping of patient key to a list of entries to append.

    Keys read back from disk are always strings, because JSON object keys are
    strings. Incoming non-string keys (for example ``7`` or ``None``) are
    therefore converted to the spelling ``json`` would write for them before
    merging. Otherwise ``7`` and ``"7"`` would become two separate keys, be
    written as duplicate ``"7"`` members, and all but one would be lost on the
    next load.
    """
    existing = {}
    if os.path.exists(out_path) and os.path.getsize(out_path) > 0:
        with open(out_path, 'r', encoding='utf-8') as f:
            existing = json.load(f)

    for patient, entries in data.items():
        key = patient
        if not isinstance(key, str):
            # Round-trip through the encoder to get exactly the key text json
            # emits ("7", "null", "true", ...). Key types json cannot encode
            # raise TypeError here, as json.dump would have done.
            (key,) = json.loads(json.dumps({patient: None}))
        existing.setdefault(key, []).extend(entries)

    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(existing, f, ensure_ascii=False, separators=(",", ":"))

def calculate_centroids_for_class(file_path, output_path, class_name, batch_size=2000):
    """Extract polygon centroids of one class and store them grouped by patient.

    Contours are streamed from ``file_path``. Each valid, non-empty polygon
    contributes ``{"Centroid": [x, y], "Id": <cell Id>}`` under its
    ``"Patient"`` value. Results are flushed to
    ``<output_path>/<class_name>_centroids.json`` every ``batch_size`` centroids.

    Args:
        file_path: Input JSON file with the contours.
        output_path: Directory receiving the centroid file.
        class_name: Class key to read; also used in the output file name.
        batch_size: Number of centroids buffered between flushes.
    """
    out_path = os.path.join(output_path, f"{class_name}_centroids.json")

    # _append_to_json merges into whatever is already on disk, so a file left
    # over from a previous run would make every centroid appear twice.
    if os.path.exists(out_path):
        os.remove(out_path)

    patients_dict = {}
    pending = 0   # centroids buffered since the last flush
    skipped = 0   # contours that raised while being processed
    contours_iter = stream_json_contours(file_path, class_name)

    for cell in tqdm(contours_iter, desc=f"{class_name}", unit="contours"):
        coords = cell.get("Coordinates")
        patient = cell.get("Patient")
        try:
            poly = Polygon(coords)
            if not poly.is_valid or poly.is_empty:
                continue
            c = poly.centroid
            centroid_entry = {
                "Centroid": [c.x, c.y],
                "Id": cell.get("Id")
            }
        except Exception:
            # Best effort: skip contours that cannot be turned into a polygon,
            # but keep count so the loss is reported.
            skipped += 1
            continue

        patients_dict.setdefault(patient, []).append(centroid_entry)
        pending += 1

        if pending >= batch_size:
            _append_to_json(out_path, patients_dict)
            patients_dict.clear()
            pending = 0
            gc.collect()

    if patients_dict:
        _append_to_json(out_path, patients_dict)

    if skipped:
        print(f"Skipped {skipped} {class_name} contours that could not be processed")
    print(f"Centroids for {class_name} saved to {out_path}")

# NOTE: this rebinds `classes` (previously the per-class settings dict) to the
# plain list of class keys whose centroids are extracted.
classes = [
    "Hepatocyte_Nuclei",
    "Other_Nuclei",
    "Fat",
    "Steatosis"
]

# One centroid file per class is written into output_path (the All_centroids directory).
for class_name in classes:
    calculate_centroids_for_class(data_path, output_path, class_name)

print("All centroids extracted successfully!!")

## Spatial feature extraction: helper definitions follow; the per-class loop runs after them

print("--- Extracting features ---")

def voronoi_finite_polygons_2d(vor, radius=None):
    """Turn the unbounded regions of a 2-D Voronoi diagram into finite polygons.

    Args:
        vor: A ``scipy.spatial.Voronoi`` diagram of 2-D points.
        radius: Distance at which open ridges are cut off. Defaults to twice
            the largest coordinate extent of the input points.

    Returns:
        ``(regions, vertices)``: ``regions[i]`` is the list of vertex indices
        of the region of input point ``i``. ``vertices`` is the original
        vertex array followed by the far points added to close open regions.

    Raises:
        ValueError: if the diagram is not two-dimensional.
    """
    if vor.points.shape[1] != 2:
        raise ValueError("Only supports 2D.")

    new_regions = []
    new_vertices = vor.vertices.tolist()
    center = vor.points.mean(axis=0)
    if radius is None:
        radius = np.ptp(vor.points, axis=0).max() * 2

    # For every input point, collect the ridges it takes part in.
    all_ridges = {}
    for (p1, p2), (v1, v2) in zip(vor.ridge_points, vor.ridge_vertices):
        all_ridges.setdefault(p1, []).append((p2, v1, v2))
        all_ridges.setdefault(p2, []).append((p1, v1, v2))

    for p1, region_index in enumerate(vor.point_region):
        region = vor.regions[region_index]
        if -1 not in region and len(region) > 0:
            # Already finite: keep as is.
            new_regions.append(region)
            continue

        ridges = all_ridges[p1]
        new_region = [v for v in region if v != -1]

        for p2, v1, v2 in ridges:
            if v2 < 0:
                v1, v2 = v2, v1
            if v1 >= 0 and v2 >= 0:
                # Finite ridge: both end points are already in the region.
                continue

            # Open ridge: extend it from its finite vertex along the normal of
            # the segment p1-p2, oriented away from the centre of the points.
            t = vor.points[p2] - vor.points[p1]
            t = np.array([-t[1], t[0]])
            t /= np.linalg.norm(t)
            midpoint = vor.points[[p1, p2]].mean(axis=0)
            direction = np.sign(np.dot(midpoint - center, t)) * t
            far_point = vor.vertices[v2] + direction * radius
            new_vertices.append(far_point.tolist())
            new_region.append(len(new_vertices) - 1)

        if new_region:
            # The far points were appended at the end of the list, so the
            # vertices are no longer in boundary order. Sort them by angle
            # around their mean (counter-clockwise) so the polygon built from
            # them does not self-intersect.
            vs = np.asarray([new_vertices[v] for v in new_region])
            c = vs.mean(axis=0)
            angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0])
            new_region = [new_region[i] for i in np.argsort(angles)]

        new_regions.append(new_region)
    return new_regions, np.array(new_vertices)
#
def get_voronoi_areas(points):
    """Return the areas of the Voronoi cells of ``points``, clipped to a padded bounding box.

    The clipping box is the bounding box of ``points`` enlarged by 10 % of its
    width and height on every side. Regions with fewer than three vertices,
    invalid polygons, and cells that do not intersect the box are left out, so
    the result can be shorter than ``points``.

    Args:
        points: (N, 2) array of positions.

    Returns:
        1-D array with one area per retained cell.
    """
    vor = Voronoi(points)
    regions, vertices = voronoi_finite_polygons_2d(vor)

    lower = np.min(points, axis=0)
    upper = np.max(points, axis=0)
    pad = (upper - lower) * 0.1
    bbox = box(lower[0] - pad[0], lower[1] - pad[1],
               upper[0] + pad[0], upper[1] + pad[1])

    areas = []
    for region in regions:
        # A polygon needs at least three vertices; skip degenerate regions
        # instead of handing them to the Polygon constructor.
        if len(region) < 3:
            continue
        polygon = Polygon(vertices[region])
        if not polygon.is_valid:
            continue
        clipped = polygon.intersection(bbox)
        if clipped.is_empty:
            continue
        areas.append(clipped.area)
    return np.array(areas)
#
def compute_density_entropy(points, grid_size=100):
    """Estimate the point density on a regular grid and return its entropy.

    A Gaussian KDE fitted to ``points`` is evaluated on a
    ``grid_size`` x ``grid_size`` grid spanning the min/max of x and y, then
    normalised to sum to 1.

    Returns:
        ``(entropy, density_map)``: the entropy of the flattened map and the
        normalised map itself.
    """
    xs, ys = np.array(points).T
    estimator = gaussian_kde([xs, ys])

    grid_x, grid_y = np.meshgrid(
        np.linspace(xs.min(), xs.max(), grid_size),
        np.linspace(ys.min(), ys.max(), grid_size),
    )
    samples = np.vstack([grid_x.ravel(), grid_y.ravel()])
    density = estimator(samples).reshape(grid_x.shape)

    prob = density / density.sum()
    return entropy(prob.ravel()), prob

def compute_fourier_features(density_map, n_features=5):
    """Return the strongest normalised spectral magnitudes of a 2-D map.

    The magnitude of the centred 2-D FFT is restricted to the quadrant
    starting at the centre, normalised to sum to 1, and the ``n_features``
    largest entries are returned ordered by increasing radial distance from
    the quadrant origin (ties broken by flat index).
    """
    spectrum = np.abs(np.fft.fftshift(np.fft.fft2(density_map)))
    rows, cols = spectrum.shape
    quadrant = spectrum[rows // 2:, cols // 2:]

    weights = quadrant.ravel()
    weights = weights / weights.sum()

    strongest = np.argsort(weights)[-n_features:]

    def radial_key(flat_idx):
        # Distance of the entry from the quadrant origin, then its flat index.
        r, c = np.unravel_index(flat_idx, quadrant.shape)
        return np.sqrt(r ** 2 + c ** 2), flat_idx

    ordered = sorted(strongest, key=radial_key)
    return weights[ordered]
#
def extract_features_from_centroids(centroid_list, k=5):
    """Compute the spatial-arrangement features of one set of centroids.

    Args:
        centroid_list: Sequence of mappings with a ``'Centroid'`` entry.
        k: Number of nearest neighbours averaged per point (the point itself
            is queried too and then excluded).

    Returns:
        Dict with ``mean_knn_dist``, ``voronoi_var`` (0 when no Voronoi area is
        available), ``density_entropy`` and ``fft1`` ... ``fftN``.
    """
    xy = np.array([entry['Centroid'] for entry in centroid_list])

    # Column 0 of the query result is each point's distance to itself.
    neighbor_dists, _ = KDTree(xy).query(xy, k=k + 1)
    cell_areas = get_voronoi_areas(xy)
    entropy_value, density_grid = compute_density_entropy(xy)
    spectrum_feats = compute_fourier_features(density_grid)

    features = {
        "mean_knn_dist": neighbor_dists[:, 1:].mean(),
        "voronoi_var": np.var(cell_areas) if len(cell_areas) > 0 else 0,
        "density_entropy": entropy_value,
    }
    for rank, value in enumerate(spectrum_feats, start=1):
        features[f"fft{rank}"] = value
    return features
#
def process_cell_class(folder, class_name, min_cells=6):
    """Build the per-image feature table of one cell class and save it.

    Reads ``<folder>/All_centroids/<class_name>_centroids.json``, numbers its
    top-level keys 1, 2, ... in file order, extracts spatial features for every
    entry with at least ``min_cells`` centroids, and writes:

    * ``<class_name>_features_extracted.xlsx`` - one row per processed entry,
      with ``voronoi_var`` stored as ``log1p`` of the raw variance;
    * ``<class_name>_key_mapping.json`` - number -> original key, for all
      entries (including the ones that were too small).

    Args:
        folder: Base directory containing ``All_centroids``.
        class_name: Cell class to process.
        min_cells: Minimum number of centroids required. The default of 6
            matches ``extract_features_from_centroids``, which by default
            queries 5 neighbours plus the point itself.
    """
    centroids_path = os.path.join(folder, f"All_centroids/{class_name}_centroids.json")
    print(f"Opening {class_name} centroids file...")
    with open(centroids_path, encoding='utf-8') as f:
        data = json.load(f)

    mapping = {}
    features = []
    numbered = enumerate(tqdm(data.items(), desc=f"Processing {class_name}", unit='image'), start=1)
    for i, (key, centroid_list) in numbered:
        img_id = str(i)
        mapping[img_id] = key
        if len(centroid_list) < min_cells:
            continue
        feats = extract_features_from_centroids(centroid_list)
        feats['img_id'] = img_id
        feats['n_cells'] = len(centroid_list)
        features.append(feats)

    if features:
        df = pd.DataFrame(features)
    else:
        # Without any row the frame would have no columns at all and the
        # column accesses below would raise KeyError.
        print(f"Warning: no {class_name} entry has at least {min_cells} centroids; "
              "writing an empty feature table.")
        df = pd.DataFrame(columns=['img_id', 'n_cells'])
    df['img_id'] = df['img_id'].astype(str)

    if "voronoi_var" in df.columns:
        df['voronoi_var'] = np.log1p(df['voronoi_var'])

    out_path = os.path.join(folder, f"All_centroids/{class_name}_features_extracted.xlsx")
    out_mapping = os.path.join(folder, f"All_centroids/{class_name}_key_mapping.json")
    df.to_excel(out_path, index=False)
    with open(out_mapping, "w", encoding='utf-8') as f:
        json.dump(mapping, f, ensure_ascii=False, separators=(",", ":"))
    print(f"Saved: {out_path}")



## Extracting features of all images
# Each class reads its centroid file from <folder>/All_centroids and writes a
# feature spreadsheet plus a key-mapping file next to it.
cell_classes = ["Fat", 'Hepatocyte_Nuclei', "Other_Nuclei", "Steatosis"]
for cell_type in cell_classes:
    process_cell_class(folder, cell_type)

## CLUSTERING
print("--- Clustering ---")

def clustering_func2(features_path, output_file, setups, use_umap=True, n_neighbors=80, min_dist=0.4):
    """Cluster one feature table under several settings and save everything to one workbook.

    All columns except ``img_id`` and ``n_cells`` are standardised and, when
    ``use_umap`` is true, embedded into 5 UMAP dimensions. For every setup a
    Ward agglomerative clustering is fitted; the input table with an added
    ``Cluster`` column goes to a sheet named after the setup. A final
    ``Metrics`` sheet lists cluster count, silhouette score and cluster sizes.

    Args:
        features_path: Excel file with the extracted features.
        output_file: Excel workbook to write.
        setups: Iterable of dicts with ``"name"`` and ``"num_cluster"``.
        use_umap: Whether to reduce the features with UMAP before clustering.
        n_neighbors: UMAP ``n_neighbors``.
        min_dist: UMAP ``min_dist``.
    """
    df_base = pd.read_excel(features_path)
    X_scaled = StandardScaler().fit_transform(df_base.drop(columns=['img_id', 'n_cells']))

    X_reduced = X_scaled
    if use_umap:
        embedder = umap.UMAP(
            n_neighbors=n_neighbors,
            n_components=5,
            min_dist=min_dist,
            metric='euclidean',
            random_state=42,
        )
        X_reduced = embedder.fit_transform(X_scaled)

    metrics_rows = []
    with pd.ExcelWriter(output_file, engine="openpyxl") as writer:
        for setup in setups:
            setup_name = setup["name"]
            clusterer = AgglomerativeClustering(n_clusters=setup['num_cluster'], linkage='ward')
            labels = clusterer.fit_predict(X_reduced)
            sil = silhouette_score(X_reduced, labels, metric='euclidean')

            # One sheet per setup: the original rows plus their cluster label.
            df_base.assign(Cluster=labels).to_excel(writer, sheet_name=setup_name, index=False)

            sizes = pd.Series(labels).value_counts().sort_index().tolist()
            metrics_rows.append({
                "Setup": setup_name,
                "n_clusters": len(set(labels)),
                "silhouette": sil,
                "count": " - ".join(str(size) for size in sizes),
            })
            print(f"Clustering {setup_name} completed.")

        pd.DataFrame(metrics_rows).to_excel(writer, sheet_name="Metrics", index=False)


def clustering_all_classes2(folder, setups):
    """Run ``clustering_func2`` for every cell class.

    Feature tables are read from ``<folder>/All_centroids`` and the resulting
    workbooks are written to ``<folder>/Cluster`` (created if missing).

    Args:
        folder: Base directory of the analysis.
        setups: Clustering setups forwarded unchanged to ``clustering_func2``.
    """
    cluster_output = os.path.join(folder, "Cluster")
    os.makedirs(cluster_output, exist_ok=True)

    for cell_type in ("Fat", "Other_Nuclei", "Hepatocyte_Nuclei", "Steatosis"):
        features_path = os.path.join(folder, f"All_centroids/{cell_type}_features_extracted.xlsx")
        output_file = os.path.join(cluster_output, f"{cell_type}_clusters.xlsx")
        print(f"\n >>> Processing {cell_type} ...")
        clustering_func2(features_path, output_file, setups)

# Clustering configurations: each entry becomes one sheet (named after "name")
# with "num_cluster" agglomerative clusters.
setups = [
    {"name": "Setup1", "num_cluster": 3},
    {"name": "Setup2", "num_cluster": 4},
    {"name": "Setup3", "num_cluster": 5},
    {"name": "Setup4", "num_cluster": 6},
]

# Runs the clustering for every cell class using the feature files under `folder`.
clustering_all_classes2(folder, setups)