import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import Voronoi, voronoi_plot_2d
import os
import cv2
from PIL import Image
from skimage.measure import regionprops, label
from skimage.morphology import remove_small_objects
import diplib as dip
from skimage.draw import polygon_perimeter
from skimage.draw import polygon as skpolygon
from openpyxl import load_workbook, Workbook
from Feature_Segmentazione import features
from natsort import natsorted
import openpyxl

def _rounded_weighted_mean(weights, values):
    """Return ``round(sum(weights * values) / sum(weights))`` as an int, or 0 when the weights sum to 0."""
    total = np.sum(weights)
    if total == 0:
        # Empty class: report 0 instead of dividing by zero (NaN -> ValueError in int()).
        return 0
    return int(round(np.sum(weights * values) / total))


def curve_creation_nuclei(count, bins):
    """Pick a nuclei threshold from a grey-level histogram.

    Starting from the histogram's weighted mean, every candidate threshold
    ``i`` gets the mean grey level of the entries below it (LOW) and above it
    (HIGH). The threshold is the first ``i`` where ``HIGH - LOW > LOW``,
    i.e. where the bright class is more than twice as bright as the dark one.

    Args:
        count: histogram counts, one per entry of ``bins``. ``segmentazione``
            passes 256 entries (a leading 0 inserted before the 255 bins of
            ``np.histogram``), so entry 1 holds the zero-valued pixels. Entry 1
            is ignored (treated as 0). The caller's array is not modified.
        bins: grey value associated with each entry of ``count``.

    Returns:
        ``(LOW_curve, DIFF_curve, HIGH_curve, th_start, thresh)`` where the
        curves hold one value per scanned candidate (LOW/HIGH as lists, DIFF
        as a numpy array), ``th_start`` is the first candidate scanned and
        ``thresh`` the selected threshold. For an empty histogram the curves
        are empty and both ``th_start`` and ``thresh`` are 0.
    """
    count = np.array(count)  # private copy: the caller's histogram stays untouched
    bins = np.asarray(bins)
    count[1] = 0

    nonzero = np.flatnonzero(count)
    if nonzero.size == 0:
        # Nothing left to threshold.
        return [], np.array([]), [], 0, 0

    # Scanning starts at the histogram's weighted average.
    th_start = int(round(np.sum(count * bins) / np.sum(count)))

    # NOTE(review): this is 256 minus the index of the FIRST non-empty entry.
    # If "one past the LAST non-empty entry" was intended (a search on the
    # reversed histogram), it needs np.flip(count); confirm before changing.
    max_value = 255 - nonzero[0] + 1

    # NOTE(review): the HIGH side stops before the last entry and neither side
    # includes entry i itself; kept as-is, confirm against the reference method.
    hi_end = len(count) - 1

    mean_low = []
    mean_high = []
    for i in range(th_start, max_value - 1):
        mean_low.append(_rounded_weighted_mean(count[:i], bins[:i]))
        mean_high.append(_rounded_weighted_mean(count[i + 1:hi_end], bins[i + 1:hi_end]))

    mean_diff = np.array(mean_high) - np.array(mean_low)
    # First candidate whose HIGH - LOW gap exceeds LOW itself.
    above = np.flatnonzero((mean_diff - np.array(mean_low)) > 0)

    if above.size:
        thresh = np.uint8(th_start + above[0])
    else:
        thresh = max_value

    return mean_low, mean_diff, mean_high, th_start, thresh

def iterate_voronoi_regions(vor):
    """Generate the vertex coordinates of each finite Voronoi cell.

    Cells are visited in input-point order (through ``vor.point_region``).
    A cell is skipped when its vertex list is empty or contains -1, the
    index scipy uses for a vertex at infinity.

    Yields:
        A list with the rows of ``vor.vertices`` referenced by the cell, in
        the order the cell lists them.
    """
    for cell_id in vor.point_region:
        vertex_ids = vor.regions[cell_id]
        is_unbounded = -1 in vertex_ids
        if is_unbounded or len(vertex_ids) == 0:
            continue
        yield [vor.vertices[vid] for vid in vertex_ids]

def _to_grayscale(image):
    """Reduce a channel-last image array to a single 2-D grey-level plane.

    Non-3-D input is returned unchanged. 3-D input with fewer than three
    channels (e.g. grey + alpha) keeps its first channel; otherwise the first
    three channels go through ``cv2.COLOR_BGR2GRAY`` and any extra channel
    (alpha) is ignored.
    """
    if image.ndim != 3:
        return image
    if image.shape[2] < 3:
        return image[:, :, 0]
    # NOTE(review): np.array(PIL image) yields RGB channel order, while
    # COLOR_BGR2GRAY treats channel 0 as blue, so the red and blue weights are
    # swapped. Kept so existing segmentations do not change; use
    # cv2.COLOR_RGB2GRAY if standard luminance was intended.
    return cv2.cvtColor(np.ascontiguousarray(image[:, :, :3]), cv2.COLOR_BGR2GRAY)


def _seed_points(image, raw_binary_mask):
    """Return the Voronoi seeds as an ``(N, 2)`` array of ``(col, row)`` points.

    Seeds are the rounded centroids of the regional maxima of ``image`` that
    land on a foreground pixel (value 1) of ``raw_binary_mask``; duplicates
    are merged and the points come out in row-major order. N may be 0.
    """
    # connectivity=2 -> 8-connected maxima
    maxima_map = np.array(dip.Maxima(dip.Image(image), connectivity=2, output='binary'))
    # reshape keeps an (0, 2) shape when no maximum is found.
    centroids = np.array(
        [region.centroid for region in regionprops(label(maxima_map))]
    ).reshape(-1, 2)
    rows_idx = np.round(centroids[:, 0]).astype(int)
    cols_idx = np.round(centroids[:, 1]).astype(int)

    inside = raw_binary_mask[rows_idx, cols_idx] == 1
    # Painting the seeds into a boolean image merges duplicates and lets
    # np.nonzero return them in row-major order.
    seed_map = np.zeros(image.shape, dtype=bool)
    seed_map[rows_idx[inside], cols_idx[inside]] = True
    seed_rows, seed_cols = np.nonzero(seed_map)
    # Voronoi works in (x, y) = (col, row) coordinates.
    return np.column_stack((seed_cols, seed_rows))


def _voronoi_mask(image, raw_binary_mask, seeds):
    """Cut the foreground along the bounded Voronoi cells of ``seeds``.

    Inside every bounded cell the foreground of ``raw_binary_mask`` is copied,
    then the cell outline is zeroed. Returns an array shaped and typed like
    ``image``. Only bounded cells contribute, so when no diagram can be built
    (no seeds, or Qhull rejects them, e.g. fewer than 3 or collinear points)
    the returned mask is empty -- the same outcome as a diagram whose cells
    are all unbounded.
    """
    maschera = np.zeros_like(image)
    if len(seeds) == 0:
        print("Nessun punto trovato per il diagramma di Voronoi.")
        return maschera
    try:
        vor = Voronoi(seeds)
    except RuntimeError as err:  # scipy's QhullError subclasses RuntimeError
        print("Voronoi diagram could not be built:", err)
        return maschera

    for polygon in iterate_voronoi_regions(vor):
        polygon = np.array(polygon)
        # Vertices are (x, y); skimage.draw expects (row, col).
        rr, cc = skpolygon(polygon[:, 1], polygon[:, 0], maschera.shape)
        maschera[rr, cc] = raw_binary_mask[rr, cc]
        # Zero the cell outline, cutting the foreground along Voronoi edges.
        rr, cc = polygon_perimeter(polygon[:, 1], polygon[:, 0], shape=maschera.shape, clip=True)
        maschera[rr, cc] = 0
    return maschera


def _mask_features(maschera):
    """Summarise the 4-connected components of ``maschera``.

    Returns ``(num_cellule, mean_area, mean_perimeter, mean_eccentricity,
    mean_orientation)``; every mean is 0 when there is no component.
    """
    properties = regionprops(label(maschera, connectivity=1))
    num_cellule = len(properties)
    if num_cellule == 0:
        # np.mean([]) would give NaN plus a RuntimeWarning; report 0 like the other means.
        return 0, 0, 0, 0, 0
    mean_area = sum(prop.area for prop in properties) / num_cellule
    mean_per = sum(prop.perimeter for prop in properties) / num_cellule
    eccentricita_media = np.mean([prop.eccentricity for prop in properties])
    orientamento_medio = np.mean([prop.orientation for prop in properties])
    return num_cellule, mean_area, mean_per, eccentricita_media, orientamento_medio


def segmentazione(list_images, result_dir, dataset_dir, riga, colonna, output, Feature=False):
    """Segment nuclei in every listed image and save one binary mask per image.

    For each file in ``dataset_dir`` the pipeline is:
      1. convert to grey levels;
      2. threshold with ``curve_creation_nuclei`` and drop objects smaller
         than 5 pixels;
      3. seed a Voronoi diagram with the regional maxima lying on the
         foreground;
      4. keep the foreground inside the bounded cells, cut along the cell
         edges;
      5. write the 0/255 mask to ``result_dir`` under the same file name.

    Special cases:
      - All-zero images are written back as-is.
      - Images for which no threshold is found are skipped; nothing is
        written for them.

    Args:
        list_images: file names (relative to ``dataset_dir``) to process.
        result_dir: directory receiving the output masks.
        dataset_dir: directory holding the input images.
        riga, colonna, output: forwarded unchanged to ``features``.
        Feature: when True, collect per-image cell statistics
            ``[filename, count, mean area, mean perimeter, mean eccentricity,
            mean orientation]`` and pass them to ``features`` at the end.
    """
    info = []
    for filename in list_images:
        # Context manager closes the file handle as soon as the pixels are read.
        with Image.open(os.path.join(dataset_dir, filename)) as pil_image:
            image = _to_grayscale(np.array(pil_image))
        print(filename)
        output_image_path = os.path.join(result_dir, filename)

        if not image.any():
            cv2.imwrite(output_image_path, image)
            continue

        # RAW segmentation
        count, bins = np.histogram(image.flatten(), 255, [0, 255])
        # Prepend an empty entry so ``count`` has one value per edge in ``bins``.
        count = np.insert(count, 0, 0)
        thresh = curve_creation_nuclei(count, bins)[4]
        if thresh <= 0:
            print("Threshold not found for image:", filename)
            continue

        raw_binary_mask = remove_small_objects(image >= thresh, min_size=5).astype(np.uint8)

        seeds = _seed_points(image, raw_binary_mask)
        maschera = _voronoi_mask(image, raw_binary_mask, seeds)
        cv2.imwrite(output_image_path, maschera.astype(np.uint8) * 255)

        if Feature:
            info.append([filename, *_mask_features(maschera)])

    if Feature:
        features(output, info, riga, colonna, list_images)
        