import cv2
import shutil
from pathlib import Path
import numpy as np
import json
import os
from tqdm import tqdm

# Parent of the current working directory; used further down as the directory
# part of the output JSON path.
folder = os.path.dirname(os.getcwd())


# RGB colour of a mask pixel -> class label. mask_to_json iterates this
# mapping, so its order is the order in which classes are extracted.
color_to_class = {
    (0, 255, 0): "Steatosis",
    (128, 128, 128): "Tissue",
    (0, 0, 255): "Hepatocyte_Nuclei",
    # (0, 51, 102): "Borders",  # currently disabled
    (255, 0, 0): "Other_Nuclei",
    (255, 255, 255): "Fat",
}

# Class label -> RGB colour (inverse of color_to_class). Its key order is the
# key order of the dictionaries that masks_folder_to_json serialises.
class_to_color = {
    "Tissue": (128, 128, 128),
    "Fat": (255, 255, 255),
    "Other_Nuclei": (255, 0, 0),
    "Hepatocyte_Nuclei": (0, 0, 255),
    "Steatosis": (0, 255, 0),
}


def mask_to_json(mask_path):
    """Extract per-class contour polygons from one colour-coded mask image.

    The image is loaded with OpenCV and converted to RGB. For every colour in
    ``color_to_class`` the pixels that match that colour exactly form a binary
    mask, whose external contours are collected.

    Args:
        mask_path: Path of the mask image. Its file name without extension is
            stored as the ``"Patient"`` value of every returned object.

    Returns:
        A dict mapping class label to a list of objects of the form
        ``{"Coordinates": [[x, y], ...], "Patient": str, "Id": int}``.
        Labels whose colour yields no contour at all are absent from the
        dict. ``"Id"`` starts at 0 for each label and counts only the
        contours that were kept.

    Raises:
        OSError: If OpenCV cannot read or decode ``mask_path``.
    """
    img = cv2.imread(mask_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than inside cvtColor.
        raise OSError(f"Could not read mask image: {mask_path}")

    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    patient_name = os.path.splitext(os.path.basename(mask_path))[0]
    result = {}

    for color, label in color_to_class.items():
        # 255 where all three channels equal the class colour, 0 elsewhere.
        binary_mask = np.all(img_rgb == np.array(color), axis=-1).astype(np.uint8) * 255
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        if len(contours) == 0:
            continue

        objects = []
        for cnt in contours:
            # Contours arrive as (N, 1, 2); flatten to (N, 2) for any N >= 1.
            coords = cnt.reshape(-1, 2)
            # Fewer than three points cannot describe a polygon.
            if coords.shape[0] < 3:
                continue

            objects.append({
                "Coordinates": coords.astype(float).tolist(),
                "Patient": patient_name,
                # Ids are consecutive over the kept contours only.
                "Id": len(objects),
            })

        result[label] = objects
    return result

def masks_folder_to_json(folder_path, output_json, batch_size=100, max_images_per_json=10000):
    """Convert every PNG mask in a folder to contour JSON, split over numbered files.

    Each ``.png`` file (case-insensitive) in ``folder_path`` is passed to
    ``mask_to_json``. The resulting objects are grouped by class label and
    written to ``<base>_<n><ext>``, where ``<base>``/``<ext>`` come from
    splitting ``output_json`` and ``<n>`` starts at 1. A new file is started
    after every ``max_images_per_json`` successfully processed images.

    Args:
        folder_path: Directory containing the mask images.
        output_json: Output path template; the counter is inserted before
            its extension.
        batch_size: Number of images staged before their objects are merged
            into the data of the current output file.
        max_images_per_json: Number of successfully processed images per
            output file.

    Returns:
        None. Results are written to disk and progress is printed.

    Notes:
        Images for which processing raises are reported with ``print`` and
        skipped; they do not count towards either limit.
    """
    # os.listdir returns names in arbitrary order; sorting makes the
    # assignment of images to output files reproducible.
    files = sorted(f for f in os.listdir(folder_path) if f.lower().endswith(".png"))
    total_files = len(files)
    print(f"Total images found: {total_files}")

    base_name, ext = os.path.splitext(output_json)

    json_counter = 1
    processed_images = 0
    batch_counter = 0

    all_data = _new_label_buckets()
    temp_data = _new_label_buckets()

    for filename in tqdm(files, desc="Processing Masks", unit='img'):
        mask_path = os.path.join(folder_path, filename)
        try:
            data = mask_to_json(mask_path)
            for label, objects in data.items():
                temp_data[label].extend(objects)
        except Exception as e:
            # Best effort: report the failing image and carry on.
            print(f"Error with {filename}: {e}")
            continue

        batch_counter += 1
        processed_images += 1

        if batch_counter >= batch_size:
            _merge_label_buckets(all_data, temp_data)
            temp_data = _new_label_buckets()
            batch_counter = 0

        if processed_images >= max_images_per_json:
            # Merge a partially filled batch before writing; otherwise its
            # objects would be discarded by the reset below whenever
            # batch_size does not divide max_images_per_json.
            _merge_label_buckets(all_data, temp_data)

            json_filename = f"{base_name}_{json_counter}{ext}"
            _write_label_json(all_data, json_filename)
            print(f"Saved {json_filename} ({processed_images} images)")

            all_data = _new_label_buckets()
            temp_data = _new_label_buckets()
            batch_counter = 0
            processed_images = 0
            json_counter += 1

    # Write whatever is left after the last full file, if it holds any object.
    _merge_label_buckets(all_data, temp_data)
    if any(all_data.values()):
        json_filename = f"{base_name}_{json_counter}{ext}"
        _write_label_json(all_data, json_filename)
        print(f"Saved last file: {json_filename}")


def _new_label_buckets():
    """Return a fresh dict with one empty list per label in ``class_to_color``."""
    return {label: [] for label in class_to_color}


def _merge_label_buckets(target, source):
    """Append each list in ``source`` to the list with the same label in ``target``."""
    for label, objects in source.items():
        target[label].extend(objects)


def _write_label_json(data, json_filename):
    """Serialise ``data`` to ``json_filename`` as compact JSON."""
    with open(json_filename, 'w', encoding="utf-8") as f:
        json.dump(data, f, separators=(',', ':'))

# Directory holding the PNG masks to convert.
mask_folder = "C:/Users/giuli/Documents/Uni/Tesi/Data/semantic/train/7"

# Output path template inside `folder`; masks_folder_to_json inserts a
# running counter before the ".json" extension.
json_path = os.path.join(folder, "contours.json")

# NOTE: there is no `__main__` guard, so the conversion runs whenever this
# module is executed or imported.
masks_folder_to_json(folder_path=mask_folder, output_json=json_path)
