import ast
import numpy as np
import xarray as xr
import pandas as pd
from pyproj import Proj, Transformer, transform #to fix issue considering that the reference systems were not explicitly stated in the objects
from shapely.geometry import box #added later o to get the bbox
import geopandas as gpd #to work wtih shp file (agricultural areas)

import time
from datetime import datetime, timedelta
import math


###################################################################################################################################################
# Function that expects a tuple (min_lon, min_lat, max_lon, max_lat).
# It rounds the minimum values down and the maximum values up to the given precision (number of decimal places; default 2).
def round_boundaries(boundaries_tuple,precision=2):
    """
    Round a bounding box outward to a fixed number of decimal places.

    Parameters:
    - boundaries_tuple: (min_lon, min_lat, max_lon, max_lat). Any axis-ordered box works,
      e.g. (x_min, y_min, x_max, y_max) in UTM.
    - precision: number of decimal places to keep (default 2; 0 rounds to whole units).
    Returns:
    - (min_lon, min_lat, max_lon, max_lat) where the minimums are floored to the precision
      and the maximums are moved to the next step above their floored value. A maximum that
      already lies exactly on a step is therefore still bumped one step up (outward padding).
    """
    min_lon,min_lat,max_lon,max_lat = boundaries_tuple
    scale = 10**precision
    # math.floor (not int) so negative coordinates are also rounded toward -inf;
    # int() truncates toward zero, which would round a negative minimum *up*.
    min_lon = math.floor(min_lon * scale) / scale
    min_lat = math.floor(min_lat * scale) / scale
    # Maximum values: floor then one step up, so the result is strictly above the floored value
    max_lon = (math.floor(max_lon * scale) + 1) / scale
    max_lat = (math.floor(max_lat * scale) + 1) / scale

    return (min_lon,min_lat,max_lon,max_lat)
###################################################################################################################################################

###################################################################################################################################################
def get_target_grid_boundaries(boundaries_tuple,sample_size=500,get_inner_boundaries=True):
    """
    Snap a bounding box onto a grid whose spacing is `sample_size`.

    The box is first rounded outward to whole units (round_boundaries with precision=0).
    Each axis is then padded symmetrically (an odd leftover unit goes to the max side) so
    that its extent becomes the smallest multiple of `sample_size` that contains it.

    Parameters:
    - boundaries_tuple: (x_min, y_min, x_max, y_max).
    - sample_size: grid spacing, in the same units as the coordinates (default 500).
    - get_inner_boundaries: if True, move each padded bound inward by sample_size/2 to get
      the coordinates of the first and last cell of the target grid; if False, return the
      padded (outermost) bounds themselves.
    Returns:
    - (x_min, y_min, x_max, y_max) of the target grid.
    """
    target_boundaries = []
    x_min, y_min, x_max, y_max = round_boundaries(boundaries_tuple,precision=0)
    # Half a cell removed from each padded end (0 when outer boundaries are requested)
    half_cell_shift = int(get_inner_boundaries)*(sample_size/2)

    for min_coord, max_coord in [(x_min,x_max),(y_min,y_max)]:  # EG: (103, 2570)
        coord_range = max_coord - min_coord  # EG: 2570 - 103 = 2467

        # Padding needed to reach the next multiple of sample_size. The outer modulo makes it 0
        # when the range is already a multiple (instead of adding a whole extra sample).
        # EG: (500 - 2467 % 500) % 500 = (500 - 467) % 500 = 33
        range_to_be_added = (sample_size - coord_range % sample_size) % sample_size
        range_to_be_added_per_boundary = int(range_to_be_added/2)  # EG: int(16.5) = 16
        extended_min_coord = min_coord - range_to_be_added_per_boundary  # EG: 103 - 16 = 87
        # The max side also receives the leftover unit when the padding is odd
        extended_max_coord = max_coord + range_to_be_added_per_boundary + (range_to_be_added%2)  # EG: 2570 + 16 + 1 = 2587

        # EG (inner): (87 + 250, 2587 - 250) = (337, 2337)
        target_boundaries.append((extended_min_coord + half_cell_shift, extended_max_coord - half_cell_shift))

    # Reorder to mimic the (x_min, y_min, x_max, y_max) ordering of the input tuple
    [(target_x_min,target_x_max),(target_y_min,target_y_max)] = target_boundaries
    return (target_x_min,target_y_min,target_x_max,target_y_max)
###################################################################################################################################################

######################################################################################################################################################
def generate_target_grid(grid_boundaries,sample_spacing=500,return_axes_flag=False):
    """
    Build the regular target grid spanning `grid_boundaries`.

    Parameters:
    - grid_boundaries: (x_min, y_min, x_max, y_max); x_min / y_min are the first grid values.
    - sample_spacing: distance between neighbouring grid values (default 500).
    - return_axes_flag: if True, return the two coordinate arrays instead of a Dataset.
    Returns:
    - (x_values, y_values) numpy arrays if return_axes_flag is True; otherwise an
      xarray.Dataset with coordinates 'x_utm' and 'y_utm' and no data variables.
    Each axis starts at its minimum and advances by sample_spacing until it reaches the
    maximum. If the extent is not a multiple of the spacing, the last value passes the maximum.
    """
    def _axis_values(start, stop):
        # np.arange(start, stop + step, step) can return one value too many with a non-integer
        # float step (e.g. (0.4 - 0.1) / 0.1 == 3.0000000000000004), so count the steps
        # explicitly. Rounding the ratio absorbs that floating-point noise.
        n_steps = math.ceil(round((stop - start) / sample_spacing, 9))
        return start + np.arange(max(n_steps + 1, 0)) * sample_spacing

    x_values = _axis_values(grid_boundaries[0], grid_boundaries[2])
    y_values = _axis_values(grid_boundaries[1], grid_boundaries[3])

    if return_axes_flag:
        return x_values,y_values

    # Empty dataset carrying only the UTM coordinate axes
    return xr.Dataset(coords={'x_utm': x_values, 'y_utm': y_values})
######################################################################################################################################################

#######################################################################################################################################
def generate_geo_dataset(grid_boundaries, data_df, initialize_to_nan=False, coord_column_name='coord_pair',data_column_name='correlation_coefficient', res_data_column_name='pearson_correlation'):
    """
    Build a Dataset on the target UTM grid and fill one data variable from a DataFrame.

    Parameters:
    - grid_boundaries: (x_min, y_min, x_max, y_max), passed to generate_target_grid
      (default spacing).
    - data_df: DataFrame with a coordinate column and a value column. It is not modified.
    - initialize_to_nan: if True, grid cells with no matching row are NaN; otherwise 0.
    - coord_column_name: column holding (x, y) tuples or their string representations,
      e.g. "(337.0, 837.0)".
    - data_column_name: column whose values are written into the grid.
    - res_data_column_name: name of the data variable added to the dataset.
    Returns:
    - xarray.Dataset with coordinates 'x_utm' and 'y_utm' and one data variable with dims
      ('y_utm', 'x_utm').
    Raises:
    - IndexError if a coordinate in data_df does not exactly match a grid value.
    - ValueError / SyntaxError (from ast.literal_eval) if a string coordinate cannot be parsed.
    """
    # Parse only the entries stored as strings (literal_eval rejects real tuples), and keep the
    # result in a local list so the caller's DataFrame is left untouched.
    coord_pairs = [ast.literal_eval(val) if isinstance(val, str) else val
                   for val in data_df[coord_column_name]]

    # Generate the target grid (in the form of an empty dataset)
    dataset_utm = generate_target_grid(grid_boundaries)
    x_values = dataset_utm['x_utm'].values
    y_values = dataset_utm['y_utm'].values

    fill_value = np.nan if initialize_to_nan else 0.0
    data_array = np.full((len(y_values), len(x_values)), fill_value)

    # Map each grid value to its index once: O(1) lookups instead of scanning the axis per row.
    # Matching is exact equality, as before, so data coordinates must be the grid's own values.
    x_index = {value: idx for idx, value in enumerate(x_values.tolist())}
    y_index = {value: idx for idx, value in enumerate(y_values.tolist())}

    for coord_pair, value in zip(coord_pairs, data_df[data_column_name]):
        try:
            x_idx = x_index[coord_pair[0]]
            y_idx = y_index[coord_pair[1]]
        except KeyError:
            # Same exception type the original np.where(...)[0][0] lookup raised, now with context
            raise IndexError(f"Coordinate {coord_pair} does not lie on the target grid") from None
        data_array[y_idx, x_idx] = value

    # Add the filled array to the dataset as a data variable
    dataset_utm[res_data_column_name] = (('y_utm', 'x_utm'), data_array)
    return dataset_utm
#######################################################################################################################################

###################################################################################################################################################
def transform_coordinates(coordinates, utm_to_wgs84=True, utm_zone=32):
    """
    Convert coordinates between a UTM zone (WGS84 ellipsoid) and WGS84 geographic (EPSG:4326).

    Parameters:
    - coordinates: a single (horizontal, vertical) tuple, or a list of such tuples.
    - utm_to_wgs84: True converts UTM (x, y) to (lon, lat); False converts (lon, lat) to UTM (x, y).
    - utm_zone: UTM zone number (default 32). No `south` flag is passed to Proj, so PROJ treats
      the zone as northern hemisphere.
    Returns:
    - A list of transformed (horizontal, vertical) tuples, even when a single tuple was given.
    - None if `coordinates` is neither a 2-element tuple nor a list.
    """
    utm_crs = Proj(proj='utm', zone=utm_zone, ellps='WGS84')
    wgs84_crs = "EPSG:4326"
    # always_xy=True keeps (lon, lat) / (x, y) ordering regardless of each CRS's declared axis order
    source_crs, target_crs = (utm_crs, wgs84_crs) if utm_to_wgs84 else (wgs84_crs, utm_crs)
    transformer = Transformer.from_proj(source_crs, target_crs, always_xy=True)

    if isinstance(coordinates, tuple) and len(coordinates) == 2:
        coordinate_pairs = [coordinates]
    elif isinstance(coordinates, list):
        coordinate_pairs = coordinates
    else:
        return None

    return [transformer.transform(*pair) for pair in coordinate_pairs]

## Example usage:
#utm_coords = [(500000, 4649775), (505000, 4650000)]
#wgs84_coords = transform_coordinates(utm_coords, utm_to_wgs84=True)
#print("UTM to WGS84:", wgs84_coords)

## Reverse transformation
#back_to_utm_coords = transform_coordinates(wgs84_coords, utm_to_wgs84=False)
#print("WGS84 back to UTM:", back_to_utm_coords)
###################################################################################################################################################


###################################################################################################################################################
#CHECKING SPATIAL BOUNDARIES OF AGRICULTURAL AREAS (AOI)
def get_aoi_boundaries(print_results=False,wgs84_flag=True,aoi_path=None):
    """
    Return the bounding box of the agricultural-area shapefile (AOI).

    Parameters:
    - print_results: if True, print the four bounds.
    - wgs84_flag: if True, return (min_lon, min_lat, max_lon, max_lat) in WGS84; otherwise
      return (min_x, min_y, max_x, max_y) in the shapefile's own CRS.
    - aoi_path: path to the shapefile; defaults to the original thesis AOI file.
    Returns:
    - 4-tuple of bounds as described above.
    NOTE(review): the WGS84 conversion uses transform_coordinates' default (UTM zone 32,
    northern hemisphere), so it assumes the shapefile is in that CRS. Confirm against gdf.crs.
    """
    if aoi_path is None:
        aoi_path = r"C:\Users\Mario\OneDrive - Politecnico di Torino\Mario Chalouhy - Thesis\04-Checking different Radarsat products\CanaleCaluso_agri_areas\CanaleCaluso_agri.shp"
    gdf = gpd.read_file(aoi_path)

    # Bounding box of all AOI geometries, in the shapefile's CRS
    minx, miny, maxx, maxy = gdf.total_bounds

    if wgs84_flag:
        # UTM grid lines are not meridians/parallels, so the lon/lat extremes of the box are not
        # necessarily at (minx, miny) and (maxx, maxy). Project all four corners and take the
        # extremes. If the box straddles the zone's central meridian, the latitude extremes may
        # lie slightly inside an edge rather than at a corner.
        corners = [(minx, miny), (minx, maxy), (maxx, miny), (maxx, maxy)]
        lons, lats = zip(*transform_coordinates(corners))
        min_lon, max_lon = min(lons), max(lons)
        min_lat, max_lat = min(lats), max(lats)
        if print_results:
            print(f'min_lon: {min_lon}')
            print(f'min_lat: {min_lat}')
            print(f'max_lon: {max_lon}')
            print(f'max_lat: {max_lat}')
        return (min_lon,min_lat,max_lon,max_lat)

    if print_results:
        print(f'min_x: {minx}\nmin_y: {miny}\nmax_x: {maxx}\nmax_y: {maxy}')
    return (minx,miny,maxx,maxy)

#END OF CHECKING SPATIAL BOUNDARIES OF AGRICULTURAL AREAS (AOI)
###################################################################################################################################################



##################################################################################
## FUNCTIONS INTRODUCED ONLY FROM RT1 FITTING:
##################################################################################

#######################################################################################################################################
def custom_resample(subset_df,window_size=12):
    """
    Aggregate one location's time series into consecutive fixed-length day windows.

    The earliest observation is emitted unchanged. After it, the span is split into windows
    of `window_size` days starting the day after the earliest date:
    [min+1, min+window_size], [min+window_size+1, min+2*window_size], ... The last window
    ends at or after the latest date. Each window yields one row dated at its last day, with
    'N' equal to the mean of 'N' over the rows inside it (NaN for an empty window).

    Parameters:
    - subset_df: DataFrame with columns 'date', 'x_UTM', 'y_UTM', 'lon', 'lat', 'coords', 'N'.
      The location columns are copied from the earliest row, presumably because they are
      constant across the subset (one pixel). TODO confirm.
      Window bounds are whole days; assumes 'date' carries no time-of-day component. TODO confirm.
    - window_size: window length in days (default 12).
    Returns:
    - DataFrame: the earliest row (all of its columns) followed by one row per window.
    Raises:
    - IndexError if subset_df is empty.
    """
    if subset_df.empty:
        raise IndexError("custom_resample() received an empty DataFrame")

    # Earliest observation, located positionally so the result does not depend on subset_df
    # being sorted by date (argmin returns the first minimum, i.e. iloc[0] for sorted input).
    first_row = subset_df.iloc[subset_df['date'].argmin()]
    x_UTM = first_row['x_UTM']
    y_UTM = first_row['y_UTM']
    lon = first_row['lon']
    lat = first_row['lat']
    coords = first_row['coords']

    min_date = subset_df['date'].min()
    max_date = subset_df['date'].max()
    # Smallest whole number of windows whose total length covers min_date..max_date
    num_days = (max_date - min_date).days
    num_windows = math.ceil(num_days / window_size)
    new_num_days = window_size * num_windows

    # The earliest observation is kept as-is (not aggregated)
    rows = [first_row.to_dict()]

    for i in range(1, new_num_days+1, window_size):
        start_date = min_date + timedelta(days=i)
        end_date = start_date + timedelta(days=window_size-1)
        # Series.between is inclusive on both ends
        window_df = subset_df.loc[subset_df['date'].between(start_date, end_date)]
        rows.append({'date': end_date, 'x_UTM': x_UTM, 'y_UTM': y_UTM,
                     'lon': lon, 'lat': lat, 'coords': coords, 'N': window_df['N'].mean()})

    return pd.DataFrame(rows)
#######################################################################################################################################