import os
import cv2
import numpy as np
import tifffile as tiff
import xlsxwriter
from natsort import natsorted

def _percentile_stretch_to_uint8(image, low, high):
    """Linearly map [low, high] onto [0, 255] and return a uint8 array.

    Values below *low* become 0 and values above *high* become 255. If the
    range is degenerate (``high <= low``, e.g. a constant image) the plain
    division would yield NaN/inf, whose cast to uint8 is undefined, so an
    all-zero image of the same shape is returned instead.
    """
    span = high - low
    if span <= 0:
        return np.zeros(np.shape(image), dtype=np.uint8)
    stretched = np.clip((image - low) / span, 0, 1)
    return (stretched * 255).astype(np.uint8)


def Normalization(image_dir,output_image_dir):
    """Percentile-stretch every ``.tif`` in *image_dir* to 8 bit and save it.

    For each image the 1st and 99.9th gray-level percentiles are computed
    (the latter rounded to 3 decimals). The image is then linearly rescaled
    so that the 1st percentile maps to 0 and the 99.9th to 255, values outside
    that range being clipped. The result is written with ``cv2.imwrite`` under
    the same file name in *output_image_dir*. The per-image percentiles are
    also stored in ``Percentili Python.xlsx`` in the current working directory.

    Args:
        image_dir: Folder containing the input ``.tif`` files.
        output_image_dir: Folder for the normalized images; created if missing.
    """
    image_files = [f for f in sorted(os.listdir(image_dir)) if f.endswith('.tif')]

    # Create the output directory if it doesn't already exist.
    os.makedirs(output_image_dir, exist_ok=True)

    info = []  # one [file name, 1st percentile, 99.9th percentile] row per image
    for image_name in image_files:
        image_path = os.path.join(image_dir, image_name)
        image = np.array(tiff.imread(image_path))

        # Robust intensity range: 1st and 99.9th percentiles of the gray levels.
        p1 = np.percentile(image, 1)
        p99_9 = round(np.percentile(image, 99.9), 3)
        info.append([image_name, p1, p99_9])

        image_normalized = _percentile_stretch_to_uint8(image, p1, p99_9)

        output_image_path = os.path.join(output_image_dir, image_name)
        cv2.imwrite(output_image_path, image_normalized)
        print(image_name)

    # The context manager guarantees the workbook is closed (written to disk)
    # even if writing a row raises.
    with xlsxwriter.Workbook('Percentili Python.xlsx') as file_excel:
        foglio_excel = file_excel.add_worksheet()
        foglio_excel.write_row(0, 0, ['Image Name', '1st Percentile', '99.9th Percentile'])
        for row_num, row_data in enumerate(info, start=1):
            foglio_excel.write_row(row_num, 0, row_data)

    print('Processing complete. Information saved to Percentili Python.xlsx.')

def Primo_Terzo_Alignment(dataset_dir, output_dir,pivot_image_name,save_interval = 10, pivot_save_interval=5):
    """Translation-register a sequence of TIFF frames with ECC and save them.

    The frames are the ``.tif`` files of *dataset_dir*, natural-sorted. The
    second frame (index 1) is the initial reference and is saved unchanged.
    Every frame is then aligned to the current reference with
    ``cv2.findTransformECC`` (``MOTION_TRANSLATION``), warped into the
    reference frame's geometry and written under its own name in
    *output_dir*. Every *save_interval* frames, the most recently processed
    frame becomes the new reference. If ECC fails for a frame, that frame is
    saved unaligned.

    Args:
        dataset_dir: Folder containing the input ``.tif`` frames.
        output_dir: Folder for the aligned frames; created if missing.
        pivot_image_name: File name of a frame that is saved unaligned and
            immediately becomes the new reference.
        save_interval: Number of processed frames between reference updates.
        pivot_save_interval: Interval used from the pivot frame onwards
            (previously hard-coded to 5).

    Raises:
        IndexError: If *dataset_dir* contains fewer than two ``.tif`` files.
    """
    os.makedirs(output_dir, exist_ok=True)
    jpeg_quality = 95
    # NOTE(review): the JPEG quality flag has no effect on .tif output files.
    write_params = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]
    # ECC stops after 100 iterations or when the update falls below 1e-5.
    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 100, 1e-5)

    def _read(name):
        """Load a frame of *dataset_dir* as a NumPy array."""
        return np.asarray(tiff.imread(os.path.join(dataset_dir, name)))

    def _write(name, img):
        """Save *img* under *name* in *output_dir*."""
        cv2.imwrite(os.path.join(output_dir, name), img, write_params)

    # Pre-sorting makes the order of natural-sort ties deterministic.
    list_image = natsorted([f for f in sorted(os.listdir(dataset_dir)) if f.endswith('.tif')])

    # The second frame is the initial reference; save it as-is. (When the loop
    # reaches it, it is aligned to itself and rewritten.)
    reference_frame = _read(list_image[1])
    _write(list_image[1], reference_frame)

    n = 0
    for i in list_image:
        if i == pivot_image_name:
            frame = _read(i)
            _write(i, frame)
            reference_frame = frame.copy()
            save_interval = pivot_save_interval
            # Force a reference change on the next successfully aligned frame.
            n = save_interval
            print(f'Frame di riferimento cambiato a immagine {i}')
            continue

        frame = _read(i)
        warp_matrix = np.eye(2, 3, dtype=np.float32)
        try:
            _, warp_matrix = cv2.findTransformECC(reference_frame, frame, warp_matrix, cv2.MOTION_TRANSLATION, criteria)
        except cv2.error as e:
            print(f"Alignment non riuscito per l'immagine {i}: {e}")
            _write(i, frame)
            n += 1
            if n % save_interval == 0:
                reference_frame = frame.copy()
                print(f'Frame di riferimento cambiato a immagine {i}')
            continue

        # ECC estimates the map from the reference (template) domain into the
        # input frame, so WARP_INVERSE_MAP resamples the frame onto the
        # reference grid. dsize is therefore the reference's (width, height);
        # the original mixed the reference width with the frame height.
        aligned_image = cv2.warpAffine(
            frame,
            warp_matrix,
            (reference_frame.shape[1], reference_frame.shape[0]),
            flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP,
        )
        _write(i, aligned_image)

        # Refresh the reference every `save_interval` processed frames.
        if n % save_interval == 0:
            reference_frame = aligned_image.copy()
            print(f'Frame di riferimento cambiato a immagine {i}')

        n += 1
        print(f'Immagine {i} processata e salvata.')

def Patches(image_folder, patch_dim_ver, patch_dim_hor, output_dir):
        """Tile each ``.tif`` image into patches, working outward from the center.

        Patches are produced in column pairs:
          - a "left" strip indexed by ``c``, moving one column left per outer
            iteration (the first one is centered on ``image_width // 2``);
          - a "right" strip indexed by ``c_d``, moving one column right.

        Within each column pair, rows are also produced in pairs:
          - a row above the vertical center, indexed by ``r`` and moving up;
          - a row below it, indexed by ``r_s`` and moving down.

        Each patch is saved with ``cv2.imwrite`` as
        ``output_dir/<row>/<row>_<col>/<name without last 4 chars>_<row>_<col><last 4 chars>``.

        Note the naming: ``patch_dim_ver`` is used as the patch WIDTH (column
        extent) and ``patch_dim_hor`` as the patch HEIGHT (row extent).

        Args:
            image_folder: Folder with the input ``.tif`` images (natural-sorted).
            patch_dim_ver: Patch width in pixels.
            patch_dim_hor: Patch height in pixels.
            output_dir: Root output folder; created if missing.
        """
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        image_files = [f for f in sorted(os.listdir(image_folder)) if f.endswith('.tif')]
        image_files = natsorted(image_files)
        for image_name in image_files:
            image = tiff.imread(os.path.join(image_folder,image_name))
            image_height, image_width = image.shape[:2]
            # Column boundaries. Both start at the same x:
            #   - patch_dim_ver_1 is the right edge of the left strip;
            #   - patch_dim_ver_2 is the left edge of the right strip.
            patch_dim_ver_1 = image_width//2 +  patch_dim_ver//2
            patch_dim_ver_2 = image_width//2 +  patch_dim_ver//2
            # Row boundaries of the patch just above the vertical center:
            #   - patch_dim_hor_1 is its bottom edge;
            #   - patch_dim_hor_2 is its top edge.
            patch_dim_hor_1 = image_height//2
            patch_dim_hor_2 = image_height//2 - patch_dim_hor

            # Grid size. The columns get one extra slot (the +1).
            n_patch_hor = np.ceil(image_height / patch_dim_hor).astype(int) 
            n_patch_ver = np.ceil(image_width / patch_dim_ver).astype(int) + 1
            # Starting indices:
            #   - r: row above center; r_s: row below center;
            #   - c: left strip; c_d: right strip.
            r = n_patch_hor//2
            c = n_patch_ver//2 + 1
            c_d = n_patch_ver//2 + 2
            r_s = n_patch_hor//2 +1

            # x: left edge of the left strip; xd: right edge of the right strip.
            # y1..y: rows of the patch just below the vertical center.
            x = patch_dim_ver_1 - patch_dim_ver
            y = patch_dim_hor_1 + patch_dim_hor
            y1 =patch_dim_hor_1
            xd = patch_dim_ver_2 + patch_dim_ver
            # Outer loop: one column pair (left strip c, right strip c_d) per pass.
            for i in range(0 , n_patch_ver//2+1):
                # Inner loop: one row pair (above r, below r_s) per pass.
                for n in range(0, np.ceil(n_patch_hor/2).astype(int)):
                    output_image_dir = os.path.join(output_dir, f'{r}/{r}_{c}')
                    if not os.path.isdir(output_image_dir):
                        os.makedirs(output_image_dir)
                    # Upper patch. Clamp its top edge at row 0 when it would go negative.
                    # NOTE(review): on the pass after this clamp, patch_dim_hor_1 has
                    # become negative. NumPy then slices from the END of the image, so
                    # the patch may be wrong or empty. Confirm whether the inner loop
                    # count can reach that case.
                    if patch_dim_hor_2 < 0:
                       patch_s = image[0:patch_dim_hor_1,x:patch_dim_ver_1]
                       patch_d = image[0:patch_dim_hor_1,patch_dim_ver_2:xd]

                    else: 
                        patch_s = image[patch_dim_hor_2:patch_dim_hor_1,x:patch_dim_ver_1]
                        patch_d = image[patch_dim_hor_2:patch_dim_hor_1,patch_dim_ver_2:xd]

                    image_name_s = image_name[:-4] + '_' + str(r) + '_' + str(c) + image_name[-4:]
                    output_image_path = os.path.join(output_image_dir, image_name_s)
                    cv2.imwrite(output_image_path, patch_s)
                    #print(str(r) + '_' + str(c))
                    # NOTE(review): this directory is created even when the right patch
                    # is skipped below (c_d == n_patch_ver + 1), leaving it empty.
                    output_image_dir = os.path.join(output_dir, f'{r}/{r}_{c_d}')
                    if not os.path.isdir(output_image_dir):
                        os.makedirs(output_image_dir)
                    # Right strip past the last column index: skip it.
                    if c_d == n_patch_ver + 1:
                        print('fine destra')
                    else: 
                        image_name_d = image_name[:-4] + '_' + str(r) + '_' + str(c_d) + image_name[-4:]
                        output_image_path = os.path.join(output_image_dir, image_name_d)
                        cv2.imwrite(output_image_path, patch_d)
                        #print(str(r) + '_' + str(c_d))

                    # Move the upper row window one patch up.
                    r = r-1
                    patch_dim_hor_1 = patch_dim_hor_2
                    patch_dim_hor_2 = patch_dim_hor_2 - patch_dim_hor

                    # Lower patch (rows y1..y). Clamp y at the image bottom.
                    # NOTE(review): in the clamped branch y1 is not advanced. Further
                    # passes can therefore re-slice the same bottom rows, or slice
                    # image[H:H], which is empty; cv2.imwrite is expected to reject an
                    # empty image. Verify against the intended grid size.
                    if y > image_height:
                       y = image_height
                       patch_s = image[y1:y,x:patch_dim_ver_1]
                       patch_d = image[y1:y,patch_dim_ver_2:xd]


                    else: 
                        patch_s = image[y1:y,x:patch_dim_ver_1]
                        patch_d = image[y1:y,patch_dim_ver_2:xd]

                        y1 = y
                        y = y + patch_dim_hor

                    image_name_n = image_name[:-4] + '_' + str(r_s) + '_' + str(c) + image_name[-4:]
                    output_image_dir = os.path.join(output_dir, f'{r_s}/{r_s}_{c}')
                    if not os.path.isdir(output_image_dir):
                        os.makedirs(output_image_dir)
                    output_image_path = os.path.join(output_image_dir, image_name_n)
                    cv2.imwrite(output_image_path, patch_s)
                    print(str(r_s) + '_' + str(c))
                    if c_d == n_patch_ver + 1:
                        print('fine destra')
                    else: 
                        image_name_d = image_name[:-4] + '_' + str(r_s) + '_' + str(c_d) + image_name[-4:]
                        output_image_dir = os.path.join(output_dir, f'{r_s}/{r_s}_{c_d}')
                        if not os.path.isdir(output_image_dir):
                            os.makedirs(output_image_dir)
                        output_image_path = os.path.join(output_image_dir, image_name_d)
                        cv2.imwrite(output_image_path, patch_d)
                        #print(str(r_s) + '_' + str(c_d))
                    r_s = r_s + 1

                # Shift both strips one patch outward and reset the row windows
                # to the vertical center.
                patch_dim_ver_1 = x
                patch_dim_ver_2 = xd
                patch_dim_hor_1 = image_height//2
                x = patch_dim_ver_1 - patch_dim_ver
                xd = patch_dim_ver_2 + patch_dim_ver
                y = patch_dim_hor_1 + patch_dim_hor
                patch_dim_hor_2 = image_height//2 - patch_dim_hor
                y1 =patch_dim_hor_1
                # Clamp strip edges to the image.
                # NOTE(review): because this is `elif`, xd is not clamped when x < 0.
                # Slicing tolerates an out-of-range stop, so that part is harmless.
                # However, once x has been clamped to 0, the next pass sets
                # patch_dim_ver_1 = 0 and the left slice becomes empty.
                if  x < 0:
                     x = 0
                elif  xd >= image_width:
                    xd = image_width

                c = c-1
                c_d = c_d+1
                r = n_patch_hor//2
                r_s = n_patch_hor//2 +1

