import geopandas as gpd
import pandas as pd
import requests
import base64
import utm
import pytz
import pickle
import os
import operator
import re
import json

from shapely.geometry import Point
from shapely.wkt import loads
from shapely.ops import nearest_points
from django.db import connections, transaction, models
from django.conf import settings
from django.db.models import Q, F, Subquery, OuterRef, Min, Max, DateField, Avg
from tensorflow import keras
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression
from scipy.interpolate import UnivariateSpline
from datetime import datetime, timedelta, date
from django.views.decorators.csrf import csrf_exempt
from statistics import mode, StatisticsError
from functools import reduce
from core.models import variables
from core.utils import fill_api_parameters, load_SQL_data, save_SQL_data
from PPA.utils import convert_time_naive
from CAL.models import CalendarEvent
from .models import Feeder, WeatherStation, Weather, WeatherForecast, \
    WeatherTest, Load, Outage

def get_weather_and_load(starttime, endtime):
    """Return feeder load and weather records for a time window.

    The ``feeder_weather_training`` table of the local ``temp_data.db`` is
    queried first; when ``load_SQL_data`` returns ``None`` the records are
    obtained from ``get_combined_data`` instead.

    Args:
        starttime: Lower bound (inclusive) for the ``time`` column.
        endtime: Upper bound (inclusive) for the ``time`` column.

    Returns:
        ``{'message': 'Empty Query'}`` when no rows are available, otherwise
        ``{'message': 'Valid Query', 'results': <JSON string of records>}``.
    """
    table_name = 'feeder_weather_training'
    # NOTE(review): the bounds are interpolated into the SQL text unescaped;
    # confirm that callers never pass untrusted values here.
    query = f'SELECT * FROM {table_name}'
    query += f' WHERE time >= "{starttime}" AND time <= "{endtime}"'

    records = load_SQL_data('temp_data.db', table_name, query, ['time'])
    if records is None:
        records = get_combined_data(starttime, endtime)

    # Guard clause: nothing to serialise.
    if records is None or records.shape[0] == 0:
        print("No Data exists for time period")
        return {'message': 'Empty Query'}

    columns_to_return = [
        'time', 'FeederName',
        'P', 'Q', 'V_AB', 'V_BC', 'V_AC', 'temperature', 'humidity', 'irradiance',
        'rainfall', 'windspeed', 'winddirection', 'riverlevel',
    ]
    serialised = records[columns_to_return].to_json(date_format='iso', orient='records')
    return {'message': 'Valid Query', 'results': serialised}


@csrf_exempt
def train_feeder_models_LF(starttime, endtime):
    """Assemble the load-forecast training set and train the feeder models.

    The training table is read from the local ``temp_data.db``. When
    ``load_SQL_data`` returns ``None`` the table is rebuilt from
    ``get_combined_data``, enriched with calendar events, outages, calendar
    features and lagged / rolling load features, and written back with
    ``save_SQL_data``. ``build_feeder_models`` is then called on the result.

    Args:
        starttime: Currently unused by this function.
        endtime: Currently unused by this function.
    """
    table_name = 'feeder_weather_training'
    columns_to_select = ['feeder_id', 'time', 'P', 'Q', 'temperature', 'humidity', 'irradiance',
                         'windspeed', 'winddirection', 'rainfall', 'riverlevel', 'hour', 'week_no',
                         'P_hr_ago', 'Q_hr_ago', 'P_day_ago', 'Q_day_ago', 'event_number', 'P_avg_3h',
                         'Q_avg_3h', 'P_avg_1d', 'Q_avg_1d', 'outage_energy']
    columns_str = ', '.join(columns_to_select)
    query = f'SELECT {columns_str} FROM {table_name}'
    combined_data = load_SQL_data('temp_data.db', table_name, query, ['time'])
    if combined_data is None:
        combined_data = get_combined_data()
        combined_data = add_calendar_events(combined_data)
        combined_data = add_outages(combined_data)

        combined_data = combined_data.drop(['id_x', 'id_y', 'type', 'P_forecast', 'Q_forecast'], axis=1)
        # Calendar features that influence the load forecast.
        combined_data['hour'] = combined_data['time'].dt.hour
        combined_data['week_no'] = combined_data['time'].dt.isocalendar().week

        # Reindex on the full (feeder, hour) grid so that every hour between
        # the first and last record exists for every feeder; the positional
        # shifts below then correspond to whole hours.
        combined_data.set_index(['feeder_id', 'time'], inplace=True)

        all_feeders = combined_data.index.get_level_values('feeder_id').unique()
        all_dates = pd.date_range(start=combined_data.index.get_level_values('time').min(),
                                  end=combined_data.index.get_level_values('time').max(), freq='H')
        new_index = pd.MultiIndex.from_product([all_feeders, all_dates], names=['feeder_id', 'time'])
        combined_data = combined_data.reindex(new_index)

        # P and Q from the previous hour
        combined_data['P_hr_ago'] = combined_data.groupby(level='feeder_id')['P'].shift(1)
        combined_data['Q_hr_ago'] = combined_data.groupby(level='feeder_id')['Q'].shift(1)

        # P and Q from the previous day at the same hour
        combined_data['P_day_ago'] = combined_data.groupby(level='feeder_id')['P'].shift(24)
        combined_data['Q_day_ago'] = combined_data.groupby(level='feeder_id')['Q'].shift(24)

        # Shift by one record so the rolling windows exclude the current row.
        P_shifted = combined_data.groupby('feeder_id')['P'].shift(1)
        Q_shifted = combined_data.groupby('feeder_id')['Q'].shift(1)

        # Average P and Q values for the past 3 hours excluding current record
        combined_data['P_avg_3h'] = P_shifted.groupby('feeder_id').rolling(window=3, min_periods=1).mean().reset_index(
            level=0, drop=True)
        combined_data['Q_avg_3h'] = Q_shifted.groupby('feeder_id').rolling(window=3, min_periods=1).mean().reset_index(
            level=0, drop=True)

        # Average P and Q values for the past 1 day (24 hours) excluding current record
        combined_data['P_avg_1d'] = P_shifted.groupby('feeder_id').rolling(window=24, min_periods=1).mean().reset_index(
            level=0, drop=True)
        combined_data['Q_avg_1d'] = Q_shifted.groupby('feeder_id').rolling(window=24, min_periods=1).mean().reset_index(
            level=0, drop=True)

        # Missing lag / rolling values become 0. Assign the result back
        # instead of calling ``fillna(inplace=True)`` on a column selection:
        # under pandas Copy-on-Write the latter only modifies a temporary.
        nans_to_zero = ['P_hr_ago', 'Q_hr_ago', 'P_day_ago', 'Q_day_ago', 'P_avg_3h', 'Q_avg_3h', 'P_avg_1d', 'Q_avg_1d']
        combined_data[nans_to_zero] = combined_data[nans_to_zero].fillna(0)
        combined_data.reset_index(inplace=True)

        save_SQL_data(combined_data, 'temp_data.db', 'feeder_weather_training')

    build_feeder_models(combined_data)

def add_calendar_events(combined_data, start_date_utc=None, end_date_utc=None):
    """Tag each record with the numeric code of the calendar event covering it.

    An ``event_number`` column is (re)initialised to 0 on ``combined_data``
    and then set, for every ``CalendarEvent`` whose type has a mapping, on the
    rows with ``event.start <= time < event.end``. Later events overwrite
    earlier ones on overlapping rows.

    Args:
        combined_data: DataFrame with a ``time`` column; modified in place.
        start_date_utc: Optional naive datetime. When both bounds are given
            they are localised to UTC and only events overlapping the window
            are considered.
        end_date_utc: Optional naive datetime, see ``start_date_utc``.

    Returns:
        The same ``combined_data`` object.
    """
    combined_data['event_number'] = 0
    event_type_mapping = {
        'holiday': 1,
        'outage': 2,
        # add more mappings as needed
    }

    # Choose the events once; the tagging loop below is shared.
    if start_date_utc is None or end_date_utc is None:
        events = CalendarEvent.objects.all()
    else:
        utc = pytz.timezone('UTC')
        window_start = utc.localize(start_date_utc)
        window_end = utc.localize(end_date_utc)
        events = CalendarEvent.objects.filter(start__lte=window_end, end__gte=window_start)

    for event in events:
        event_number = event_type_mapping.get(event.type, None)
        if event_number is None:
            # Event types without a mapping leave the records untouched.
            continue
        covered = (combined_data['time'] >= event.start) & (combined_data['time'] < event.end)
        combined_data.loc[covered, 'event_number'] = event_number

    return combined_data

def add_outages(combined_data, start_date_utc=None, end_date_utc=None):
    """Merge per-feeder, per-hour outage energy into ``combined_data``.

    Steps:
      1. Load ``Outage`` rows (optionally only those overlapping the window).
      2. Map the outage ``LC`` codes to load identifiers.
      3. Split outages recorded for ``LC == 'All'`` or ``feeder == 'All'``
         across the individual feeders, weighting ``energy`` by each feeder's
         ``load_typical`` share.
      4. Expand every outage to one row per hour it touches and weight its
         energy with ``get_fraction_of_hour`` (then divide by 1000).
      5. Sum per ``(feeder_id, time)`` and left-merge as ``outage_energy``;
         records without an outage get 0.

    Args:
        combined_data: DataFrame with ``feeder_id`` and ``time`` columns.
        start_date_utc: Optional naive datetime. When both bounds are given
            they are localised to UTC and used to filter the outages.
        end_date_utc: Optional naive datetime, see ``start_date_utc``.

    Returns:
        A DataFrame equal to ``combined_data`` plus an ``outage_energy``
        column.
    """
    # Convert Outage model to DataFrame
    if start_date_utc is not None and end_date_utc is not None:
        start_date_localized = pytz.timezone('UTC').localize(start_date_utc)
        end_date_localized = pytz.timezone('UTC').localize(end_date_utc)
        outages_df = pd.DataFrame(list(Outage.objects.filter(
            time_out__lte=end_date_localized, time_in__gte=start_date_localized
        ).values()))
    else:
        outages_df = pd.DataFrame(list(Outage.objects.all().values()))

    if outages_df.empty:
        # Without outage rows the frame has no columns at all, so the column
        # lookups below would raise KeyError. Every record simply has zero
        # outage energy.
        combined_data = combined_data.copy()
        combined_data['outage_energy'] = 0.0
        return combined_data

    # LC to load_identifier lookup dictionary
    lc_lookup = {
        'Ldv': 'LDV',
        'All': 'All',
        'Bmp': 'BMP',
        'Owk': 'OWK',
        'Czl': 'CZL',
        'Spd': 'SPR',
        'Ind': 'IND',
        'Dan': 'DGA',
        'Bze': 'BZE',
        'Pga': 'PGA',
        'Sig': 'SIG',
        'Cck': 'CCK',
        'Wst': 'WST',
        'Mul': 'MUL',
    }
    outages_df['LC'] = outages_df['LC'].replace(lc_lookup)

    # load_identifier -> list of feeder suffixes (last character of the
    # stripped FeederName), built once.
    feeders = Feeder.objects.values('load_identifier', 'FeederName')
    lc_to_feeders = {}
    for f in feeders:
        if f['load_identifier'] not in lc_to_feeders:
            lc_to_feeders[f['load_identifier']] = []
        lc_to_feeders[f['load_identifier']].append(f['FeederName'].strip()[-1])

    feedername_to_id = {f['FeederName'].strip(): f['id'] for f in Feeder.objects.values('FeederName', 'id')}
    load_typical_lookup = {f['FeederName'].strip(): f['load_typical'] for f in
                           Feeder.objects.values('FeederName', 'load_typical')}

    # For LC='All', replicate the row for each feeder across all load centers
    all_lc_df = outages_df[outages_df['LC'] == 'All'].copy()
    outages_df = outages_df[outages_df['LC'] != 'All']

    # The system-wide typical load does not depend on the outage row, so it is
    # computed once. A zero total falls back to 1 to avoid dividing by zero.
    system_typical_load = sum(load_typical_lookup.get(lc + feeder_name.zfill(2), 0)
                              for lc in lc_to_feeders for feeder_name in lc_to_feeders[lc])
    if system_typical_load == 0:
        system_typical_load = 1

    for _, row in all_lc_df.iterrows():
        for lc, feeders_for_lc in lc_to_feeders.items():
            replicated_df = pd.concat([row] * len(feeders_for_lc), axis=1).transpose()
            replicated_df.reset_index(drop=True, inplace=True)
            replicated_df['LC'] = lc
            replicated_df['feeder'] = feeders_for_lc
            for idx, feeder_no in enumerate(feeders_for_lc):
                feeder_key = lc + feeder_no.zfill(2)  # Construct the key like 'BZE01', 'BZE02', etc.
                replicated_df.loc[idx, 'energy'] *= load_typical_lookup.get(feeder_key, 0) / system_typical_load
            outages_df = pd.concat([outages_df, replicated_df])

    # For 'All' feeders, replicate the row for each feeder under the LC
    all_feeders_df = outages_df[outages_df['feeder'] == 'All'].copy()
    outages_df = outages_df[outages_df['feeder'] != 'All']
    for _, row in all_feeders_df.iterrows():
        feeders_for_lc = lc_to_feeders.get(row['LC'], [])
        if not feeders_for_lc:
            # pd.concat([]) raises; an outage for a load centre without any
            # known feeder cannot be attributed, so it is skipped.
            print(f"Skipping outage for unknown load centre '{row['LC']}'.")
            continue
        total_typical_load = sum(load_typical_lookup.get(row['LC'] + feeder.zfill(2), 0)
                                 for feeder in feeders_for_lc)
        if total_typical_load == 0:
            total_typical_load = 1
        replicated_df = pd.concat([row] * len(feeders_for_lc), axis=1).transpose()
        replicated_df.reset_index(drop=True, inplace=True)
        replicated_df['feeder'] = feeders_for_lc
        for idx, feeder_no in enumerate(feeders_for_lc):
            feeder_key = row['LC'] + feeder_no.zfill(2)  # Construct the key like 'BZE01', 'BZE02', etc.
            replicated_df.loc[idx, 'energy'] *= load_typical_lookup.get(feeder_key, 0) / total_typical_load
        outages_df = pd.concat([outages_df, replicated_df])

    if outages_df.empty:
        # Every outage was skipped above; nothing to merge.
        combined_data = combined_data.copy()
        combined_data['outage_energy'] = 0.0
        return combined_data

    # Keep the exact outage bounds; the working copies are floored to the hour.
    outages_df['original_time_out'] = outages_df['time_out']
    outages_df['original_time_in'] = outages_df['time_in']
    outages_df['time_out'] = pd.to_datetime(outages_df['time_out'])
    outages_df['time_in'] = pd.to_datetime(outages_df['time_in'])
    outages_df['original_time_out'] = pd.to_datetime(outages_df['original_time_out'])
    outages_df['original_time_in'] = pd.to_datetime(outages_df['original_time_in'])

    # Adjust time_out / time_in to match the hour precision
    outages_df['time_out'] = outages_df['time_out'].dt.floor('H')
    outages_df['time_in'] = outages_df['time_in'].dt.floor('H')

    # Expand the dataframe to cover each hour during the outage
    date_ranges = outages_df.apply(lambda row: pd.date_range(start=row['time_out'], end=row['time_in'], freq='H'),
                                   axis=1)
    # Explode the list into separate rows
    outages_df = outages_df.assign(current_hour=date_ranges).explode('current_hour').reset_index(drop=True)

    # Weight the energy by the fraction of each hour the outage covers
    outages_df['weighted_energy'] = outages_df.apply(
        lambda row: (row['energy'] * get_fraction_of_hour(row['original_time_out'], row['original_time_in'],
                                                          row['current_hour'].hour)) / 1000, axis=1)
    # Assign back rather than fillna(inplace=True) on a column selection,
    # which only touches a temporary under pandas Copy-on-Write.
    outages_df['weighted_energy'] = outages_df['weighted_energy'].fillna(0)
    # Rename time and current_hour column
    outages_df.rename(columns={
        'current_hour': 'time',
        'weighted_energy': 'outage_energy'
    }, inplace=True)

    # Combine LC and feeder to get feeder_id
    outages_df['FeederName'] = outages_df['LC'].str.strip() + outages_df['feeder'].str.strip().str.zfill(2)
    outages_df['feeder_id'] = outages_df['FeederName'].map(feedername_to_id)
    outages_df.drop(['FeederName', 'minutes', 'time_out', 'time_in', 'id', 'energy'], axis=1, inplace=True)
    aggregation = {
        'LC': 'first',
        'feeder': 'first',
        'zone': 'first',
        'outage_energy': 'sum',
        'original_time_out': 'first',
        'original_time_in': 'first',
        'source': 'first',
        'type': 'first',
    }

    # Group by 'feeder_id' and 'time', then aggregate
    outages_df = outages_df.groupby(['feeder_id', 'time']).agg(aggregation)
    outages_df.reset_index(inplace=True)
    outages_df['feeder_id'] = outages_df['feeder_id'].astype(int)
    combined_data = pd.merge(combined_data, outages_df[['feeder_id', 'time', 'outage_energy']],
                             on=['feeder_id', 'time'], how='left')
    combined_data['outage_energy'] = combined_data['outage_energy'].fillna(0)

    return combined_data


def build_feeder_models(combined_data):
    """Train and store a load model for every feeder that has none yet.

    For each ``feeder_id`` in ``combined_data`` the records are restricted to
    rows where all of the weather station's ``Training_fields`` are present.
    A feeder is trained only when its ``trained_model`` is currently NULL and
    more than 1000 such records remain. The model is a small feed-forward
    network predicting ``P`` and ``Q``; it is pickled and saved on the
    ``Feeder`` row together with the training metadata.

    Args:
        combined_data: DataFrame holding ``feeder_id``, ``time``, ``P``, ``Q``
            and the feature columns.
    """
    unique_feeders = combined_data['feeder_id'].unique()
    # Materialise the ids once so the per-feeder membership test is a set
    # lookup rather than a scan over the queryset results.
    untrained_ids = set(Feeder.objects.filter(trained_model__isnull=True).values_list('id', flat=True))
    additional_elements = ['hour', 'week_no', 'P_hr_ago', 'Q_hr_ago', 'P_day_ago', 'Q_day_ago',
                           'event_number', 'P_avg_3h', 'Q_avg_3h', 'P_avg_1d', 'Q_avg_1d', 'outage_energy']

    for feeder_id in unique_feeders:
        feeder = Feeder.objects.get(id=feeder_id)
        feeder_data = combined_data[combined_data['feeder_id'] == feeder_id].copy()
        weather_station = WeatherStation.objects.get(id=feeder.weather_station_id)
        trained_feature_list = weather_station.Training_fields.split(',')
        feeder_data.dropna(subset=trained_feature_list, how='any', inplace=True)
        feeder_name = feeder.FeederName
        num_records = len(feeder_data)
        time_min = feeder_data['time'].min()
        time_max = feeder_data['time'].max()
        print(f"Feeder: {feeder_name} has {num_records} load + weather records used in training.")

        if feeder_id not in untrained_ids or num_records <= 1000:
            continue

        # Create features and targets
        trained_feature_list = trained_feature_list + additional_elements
        X = feeder_data[trained_feature_list]
        y = feeder_data[['P', 'Q']]

        # Split data into training and validation sets
        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

        # Normalize features; the scaler is fitted on the training split only.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_val = scaler.transform(X_val)

        # Create a feedforward neural network model
        model = keras.Sequential([
            keras.layers.Dense(64, activation='relu', input_shape=(X_train.shape[1],)),
            keras.layers.Dense(32, activation='relu'),
            keras.layers.Dense(2)  # Output layer for predicting both P and Q
        ])
        model.compile(optimizer='adam', loss='mean_squared_error')
        model.fit(X_train, y_train, epochs=100, batch_size=32, validation_data=(X_val, y_val))

        # NOTE(review): the fitted scaler is not stored with the model here;
        # confirm that prediction code reproduces the same normalisation.
        serialized_model = pickle.dumps(model)

        # Load a fresh instance before saving: save() writes every field, and
        # training may have taken a while since `feeder` was read.
        # Feeder.objects.get() raises DoesNotExist rather than returning None,
        # so no truthiness check is needed.
        feeder_instance = Feeder.objects.get(id=feeder_id)
        feeder_instance.trained_model = serialized_model
        feeder_instance.trained_start = time_min
        feeder_instance.trained_end = time_max
        feeder_instance.trained_id = feeder.weather_station_id
        feeder_instance.trained_records = num_records
        feeder_instance.trained_features = ','.join(trained_feature_list)
        feeder_instance.save()

def get_fraction_of_hour(original_time_out, original_time_in, current_hour):
    """Return the fraction of the clock hour ``current_hour`` spent in outage.

    Args:
        original_time_out: Datetime at which the outage started.
        original_time_in: Datetime at which the outage ended.
        current_hour: Hour of day (0-23) of the hourly bucket being weighted.

    Returns:
        A value between 0 and 1: the covered share of the first or last hour
        of the outage, or 1 for any other hour.

    Note:
        Only the hour of day is compared, so for outages spanning several
        days every bucket whose hour of day equals the start (or end) hour is
        treated as the first (or last) hour.
    """
    if current_hour == original_time_out.hour:
        # An outage that starts and ends inside the same clock hour covers
        # only the minutes between the two timestamps, not the rest of the
        # hour.
        if (original_time_in.hour == original_time_out.hour
                and original_time_in.date() == original_time_out.date()):
            return (original_time_in.minute - original_time_out.minute) / 60
        return (60 - original_time_out.minute) / 60
    if current_hour == original_time_in.hour:
        return original_time_in.minute / 60
    return 1
def fill_missing_values(group, column, engine):
    """Fill runs of NaN in ``group[column]`` using the selected engine.

    NaN runs are located with ``get_nan_spans``. The index entries are
    indexed with ``[1]`` to obtain a timestamp, i.e. ``group`` is expected to
    carry a two-level index whose second level is the time. Run lengths are
    measured as ``(end - start)`` in hours plus one (this equals the number of
    records assuming hourly data).

    For ``column == 'irradiance'`` runs longer than ``min_valid_span_irr`` and
    at most ``max_valid_span_irr`` hours are set to 0 first. Then every run of
    at most ``max_valid_span`` hours is handed to the engine together with a
    training window of up to ``training_horizon`` index positions on either
    side, clipped at the neighbouring fillable runs.

    NOTE(review): the run list is computed before the irradiance zero-fill,
    so zero-filled runs of at most ``max_valid_span`` hours are still passed
    to the engine afterwards - confirm that this is intended.

    Args:
        group: DataFrame to fill.
        column: Name of the column to fill.
        engine: ``"ML"`` or ``"spline"``; any other value selects plain
            interpolation.

    Returns:
        The DataFrame returned by the last engine call (``group`` itself when
        there was nothing to fill).
    """
    training_horizon = 40  # the number of records before and after NaN span to use for training
    max_valid_span = 24  # longest NaN span (in hours) that is still filled
    min_valid_span_irr = 6  # irradiance NaN spans must be longer than this (hours) to be filled with 0
    max_valid_span_irr = 16  # irradiance NaN spans up to this length (hours) are filled with 0
    # Identify spans of NaN values
    nan_spans = get_nan_spans(group[column])

    # For the 'irradiance' column, fill NaN spans longer than min_valid_span_irr
    # and at most max_valid_span_irr hours with 0
    if column == 'irradiance':
        for start, end in nan_spans:
            # Extract the actual timestamp from the tuple
            start_time = start[1]
            end_time = end[1]
            hour_difference = (end_time - start_time).total_seconds() / 3600 + 1
            # If the span length lies in (min_valid_span_irr, max_valid_span_irr], fill with 0s
            if min_valid_span_irr < hour_difference <= max_valid_span_irr:
                group.loc[start:end, column] = 0

    # Keep only the spans that are at most max_valid_span hours long
    valid_spans = []
    for start, end in nan_spans:
        # Extract the actual timestamp from the tuple
        start_time = start[1]
        end_time = end[1]

        # Calculate the difference in hours between the start and end timestamps
        hour_difference = (end_time - start_time).total_seconds() / 3600 + 1

        # If the difference is less than or equal to max_valid_span, add it to valid_spans
        if 0 < hour_difference <= max_valid_span:
            valid_spans.append((start, end))

    # For each valid span, fill in missing values
    first_index = group.index.get_loc(group.index[0])
    last_index = group.index.get_loc(group.index[-1])
    # Each span is paired with its predecessor and successor; the shortest list
    # has len(valid_spans) entries, so zip never reaches the trailing
    # (None, None) of extended_valid_spans.
    previous_valid_spans = [(None, None)] + valid_spans[:-1]
    extended_valid_spans = valid_spans + [(None, None)]
    next_valid_spans = valid_spans[1:] + [(None, None)]

    for (prev_start, prev_end), (start, end), (next_start, next_end) in zip(previous_valid_spans, extended_valid_spans,
                                                                            next_valid_spans):
        if start is None or end is None:
            continue
        position_start = group.index.get_loc(start)
        position_end = group.index.get_loc(end)
        # The training window starts at most training_horizon positions before
        # the span, but not before the end of the previous span.
        if prev_end:
            prev_position_end = group.index.get_loc(prev_end)
            adjusted_start = max(first_index, position_start - training_horizon, prev_position_end)
        else:
            adjusted_start = max(first_index, position_start - training_horizon)
        # ...and ends at most training_horizon positions after the span, but
        # not after the start of the next span.
        if next_start:
            next_position_start = group.index.get_loc(next_start)
            adjusted_end = min(last_index, position_end + training_horizon, next_position_start)
        else:
            adjusted_end = min(last_index, position_end + training_horizon)

        # Taken from the current `group`, so values filled for earlier spans
        # are visible to later ones.
        train_data = group.loc[pd.IndexSlice[group.index[adjusted_start]:group.index[adjusted_end]]]

        if engine == "ML":
            group = filler_engine_ML(train_data, column, group, start, end)
        elif engine == "spline":
            group = filler_engine_spline(train_data, column, group, start, end)
        else:
            group = filler_engine_interpolate(train_data, column, group)

    return group

def filler_engine_ML(train_data, column, group, start, end):
    """Fill ``group[column]`` between ``start`` and ``end`` by linear regression.

    A ``LinearRegression`` is fitted on the rows of ``train_data`` where
    ``column`` and all feature columns are present, and used to predict
    ``column`` for the rows of ``group.loc[start:end]`` that have all feature
    columns. Features are the candidate columns below (minus ``column``
    itself) that actually exist in ``group``.

    Args:
        train_data: Rows surrounding the gap, used for fitting.
        column: Name of the column to predict.
        group: DataFrame to update in place.
        start: Index label of the first row of the gap.
        end: Index label of the last row of the gap.

    Returns:
        ``group``; unchanged when there is nothing to fit on or to predict.
    """
    columns_for_features = ['P', 'Q', 'temperature', 'humidity', 'irradiance',
                            'windspeed', 'winddirection', 'hour', 'week_no']
    if column in columns_for_features:
        columns_for_features.remove(column)
    # Only the candidate columns that are present in the frame can be used;
    # selecting or dropna-ing on absent names would raise KeyError.
    features = list(group.columns.intersection(columns_for_features))

    # The regression cannot handle NaN, so training rows need the target and
    # every feature, and prediction rows need every feature.
    train_data = train_data[train_data[column].notnull()].dropna(subset=features)
    test_data = group.loc[pd.IndexSlice[start:end]].dropna(subset=features)

    if not features or train_data.empty or test_data.empty:
        # Nothing to fit on, or nothing that can be predicted: leave the gap.
        return group

    # Train a simple linear regression model
    lr = LinearRegression()
    lr.fit(train_data[features], train_data[column])

    # Fill the gap rows with the predicted values
    group.loc[test_data.index, column] = lr.predict(test_data[features])

    return group
def filler_engine_interpolate(train_data, column, group):
    """Fill ``group[column]`` with interpolated values from ``train_data``.

    The column is interpolated over the rows of ``train_data`` in both
    directions (so leading and trailing NaNs of the window are filled too),
    and the result is written to the matching index labels of ``group``.

    Args:
        train_data: Window of rows to interpolate over.
        column: Name of the column to fill.
        group: DataFrame updated in place.

    Returns:
        ``group``.
    """
    window = train_data[column]
    group.loc[window.index, column] = window.interpolate(limit_direction='both')
    return group


def filler_engine_spline(train_data, column, group, start, end):
    """Fill ``group[column]`` between ``start`` and ``end`` with a smoothing spline.

    A ``UnivariateSpline`` (``s=5``) is fitted to the non-NaN values of
    ``train_data[column]`` against the second index level converted to
    integers, then evaluated on the rows of ``train_data.loc[start:end]``.
    If the spline cannot be created the gap is filled with
    ``filler_engine_interpolate`` instead.

    Args:
        train_data: Rows surrounding (and including) the gap.
        column: Name of the column to fill.
        group: DataFrame updated in place.
        start: Index label of the first row of the gap.
        end: Index label of the last row of the gap.

    Returns:
        ``group``.
    """
    # Integer representation of the timestamps (second level of the MultiIndex).
    time_index = train_data.index.get_level_values(1).astype('int64')

    # Only consider non-NaN values for generating the spline
    known = train_data[column].notna().to_numpy()
    non_na_data_values = train_data[column].dropna().values
    non_na_time_index = time_index[known]

    # The "s" parameter controls the smoothness; 0 would give pure interpolation.
    try:
        spline = UnivariateSpline(non_na_time_index, non_na_data_values, s=5)
    except Exception as e:
        # Deliberate best effort: report the inputs and fall back to plain
        # interpolation instead of failing the whole fill.
        print(f"Error occurred for column '{column}': {str(e)}. attempting to interpolate.")
        print("non_na_time_index:", non_na_time_index)
        print("non_na_data_values:", non_na_data_values)
        group = filler_engine_interpolate(train_data, column, group)
        return group

    # Take the rows to fill and their x-values from the same slice. The x
    # values therefore use the same integer conversion as the fit, and their
    # number always matches the rows being assigned (comparing against
    # Timestamp.value, which is always in nanoseconds, does not guarantee
    # either for indexes with another resolution).
    gap = train_data.loc[start:end]
    gap_times = gap.index.get_level_values(1).astype('int64')
    group.loc[gap.index, column] = spline(gap_times)

    return group
def get_nan_spans(s):
    """Return the runs of consecutive NaN values in a Series.

    Args:
        s: Series to inspect.

    Returns:
        A list of ``(first_label, last_label)`` index-label pairs, one per run
        of consecutive NaNs, in index order. Runs touching the beginning or
        the end of the series are included.
    """
    is_na = s.isnull()

    # A run starts where the value is NaN and the previous one is not (or
    # there is no previous value); it ends where the value is NaN and the next
    # one is not (or there is no next value). Shifting with fill_value=False
    # keeps the masks boolean, so a run reaching the last element still gets
    # an end label instead of being dropped.
    run_starts = (is_na & ~is_na.shift(1, fill_value=False)).to_numpy()
    run_ends = (is_na & ~is_na.shift(-1, fill_value=False)).to_numpy()

    idx_start = s.index[run_starts].tolist()
    idx_end = s.index[run_ends].tolist()

    return list(zip(idx_start, idx_end))


def get_feeder_nearest_ws():
    """Return feeders annotated with details of their linked weather station.

    Returns:
        A ``values()`` queryset with one dict per feeder containing the feeder
        fields (minus the geometry and trained-model fields) and the
        ``Name``, ``Latitude`` and ``Longitude`` of the ``WeatherStation``
        referenced by ``weather_station_id``, exposed as
        ``weather_station_name`` / ``_latitude`` / ``_longitude``.
    """
    # Fields left out of the result.
    exclude_feeder_fields = ['Shape', 'trained_model', 'trained_start', 'trained_end', 'trained_records', 'trained_features']
    feeder_fields = [
        field.name
        for field in Feeder._meta.get_fields()
        if field.name not in exclude_feeder_fields
    ]

    # Correlated lookup of the station referenced by each feeder row.
    station_lookup = WeatherStation.objects.filter(
        id=OuterRef('weather_station_id')
    ).values('Name', 'Latitude', 'Longitude')[:1]

    station_annotations = {
        'weather_station_name': Subquery(station_lookup.values('Name')),
        'weather_station_latitude': Subquery(station_lookup.values('Latitude')),
        'weather_station_longitude': Subquery(station_lookup.values('Longitude')),
    }

    return (
        Feeder.objects
        .only(*feeder_fields)
        .annotate(**station_annotations)
        .values(*feeder_fields, *station_annotations)
    )


def feeder_spatial_join():
    """Re-link feeders to their nearest weather station and refresh dependent data.

    Computes the nearest station per feeder (stations with ``has_data=1`` and
    ``data_count`` above 10000), stores the relationship on the feeders, marks
    exactly the linked stations as ``valid``, removes the local
    ``temp_data.db`` so it is rebuilt, and finally runs ``clean_weather_data``.
    """
    result_df = calc_nearest_weather_station(1, 10000)
    replace_weather_station_relationship(result_df)

    # Only stations that ended up linked to a feeder stay valid.
    linked_station_ids = result_df['weather_station_id'].unique()
    WeatherStation.objects.exclude(id__in=linked_station_ids).update(valid=0)
    WeatherStation.objects.filter(id__in=linked_station_ids).update(valid=1)

    # Remove the temporary database file so that it is rebuilt.
    file_to_delete = os.path.join(os.getcwd(), 'temp_data.db')
    if not os.path.exists(file_to_delete):
        print(f"No file found at '{file_to_delete}'.")
    else:
        os.remove(file_to_delete)
        print(f"The file '{file_to_delete}' has been deleted.")

    # Fill in missing weather data and store in local sql db
    clean_weather_data()

def calc_nearest_weather_station(has_data=None, data_count_min=None):
    """Find the weather station closest to each feeder's centroid.

    Stations flagged ``exclude=True`` are never considered; ``has_data`` and
    ``data_count_min`` are passed on to ``get_weather_station_coordinates``.

    Args:
        has_data: Optional ``has_data`` filter for the stations.
        data_count_min: Optional lower bound (exclusive) on ``data_count``.

    Returns:
        A DataFrame with one row per feeder: ``feeder_id``, the feeder fields,
        ``weather_station_id``, the station fields and ``distance``, followed
        by any further feeder columns (e.g. ``centroid``).
    """
    # Create GeoDataFrame from the feeders and weather stations
    feeders_gdf, feeder_fields = get_feeder_centroid()

    exclude_ids = [station.id for station in WeatherStation.objects.filter(exclude=True)]

    weather_stations_gdf, weather_station_fields = get_weather_station_coordinates(has_data, data_count_min, exclude_ids)

    base_columns = (['feeder_id'] + feeder_fields[1:] + ['weather_station_id']
                    + weather_station_fields[:] + ['distance'])

    if feeders_gdf.empty:
        return pd.DataFrame(columns=base_columns)

    # The union of all station points is the same for every feeder, so it is
    # built once instead of once per iteration.
    all_stations = weather_stations_gdf.unary_union

    # Collect the rows and build the frame once; appending to a DataFrame row
    # by row copies the whole frame each time.
    rows = []
    for feeder_id, centroid in feeders_gdf[['id', 'centroid']].itertuples(index=False):
        nearest_station = nearest_points(centroid, all_stations)[1]
        distance = centroid.distance(nearest_station)
        feeder_row = feeders_gdf.loc[feeders_gdf['id'] == feeder_id].iloc[0]
        station_row = weather_stations_gdf.loc[weather_stations_gdf.geometry == nearest_station].iloc[0]
        rows.append({
            'feeder_id': feeder_id,
            **feeder_row.drop(['geometry', 'id']).to_dict(),
            'weather_station_id': station_row['id'],
            **station_row.drop(['geometry', 'X', 'Y', 'id']).to_dict(),
            'distance': distance
        })

    result_df = pd.DataFrame(rows)
    # Documented columns first, any extra feeder columns afterwards.
    extra_columns = [name for name in result_df.columns if name not in base_columns]
    return result_df[base_columns + extra_columns]


def get_feeder_centroid():
    """Load all feeders into a GeoDataFrame and add their centroids.

    Returns:
        A tuple ``(feeders_gdf, feeder_fields)``: the GeoDataFrame has the
        columns listed in ``feeder_fields``, a ``geometry`` column parsed from
        the WKT in ``Shape`` and a ``centroid`` column.
    """
    feeder_fields = ['id', 'FeederName', 'Shape']

    # One dict per feeder, restricted to the required fields.
    feeders_data = list(Feeder.objects.only(*feeder_fields).values(*feeder_fields))

    shapes = [loads(record['Shape']) for record in feeders_data]
    feeders_gdf = gpd.GeoDataFrame(feeders_data, geometry=shapes)
    feeders_gdf['centroid'] = feeders_gdf['geometry'].centroid

    return feeders_gdf, feeder_fields


def get_weather_station_coordinates(has_data=None, data_count_min=None, exclude_ids=None):
    """Load weather stations into a GeoDataFrame of points.

    Args:
        has_data: ``has_data`` value to filter on. The filter is applied only
            when both ``has_data`` and ``data_count_min`` are given.
        data_count_min: Lower bound (exclusive) on ``data_count``.
        exclude_ids: Optional iterable of station ids to leave out.

    Returns:
        A tuple ``(weather_stations_gdf, weather_station_fields)``: the
        GeoDataFrame has ``id``, ``X``, ``Y`` and the listed fields, with a
        ``Point(X, Y)`` geometry per station.
    """
    weather_station_fields = ['Name', 'Latitude', 'Longitude']

    if has_data is not None and data_count_min is not None:
        weather_stations = WeatherStation.objects.filter(has_data=has_data, data_count__gt=data_count_min)
    else:
        weather_stations = WeatherStation.objects.all()
    if exclude_ids:
        weather_stations = weather_stations.exclude(id__in=exclude_ids)

    # Build the attribute records and the matching geometries in one pass.
    weather_stations_data = []
    station_points = []
    for station in weather_stations:
        record = {'id': station.id, 'X': station.X, 'Y': station.Y}
        record.update((name, getattr(station, name)) for name in weather_station_fields)
        weather_stations_data.append(record)
        station_points.append(Point(station.X, station.Y))

    weather_stations_gdf = gpd.GeoDataFrame(weather_stations_data, geometry=station_points)
    return weather_stations_gdf, weather_station_fields


def pd_to_context(queryset):
    """Materialise an iterable of records into a list.

    Args:
        queryset: Any iterable of records (e.g. a ``values()`` queryset).

    Returns:
        A new list holding the records in iteration order.
    """
    return list(queryset)


def replace_weather_station_relationship(results_pd):
    """Store the nearest-station assignment on the feeders.

    Args:
        results_pd: DataFrame with ``feeder_id``, ``weather_station_id`` and
            ``distance`` columns; each row updates (or creates) the ``Feeder``
            with that id.
    """
    for _, assignment in results_pd.iterrows():
        Feeder.objects.update_or_create(
            id=assignment['feeder_id'],
            defaults={
                'weather_station_id': assignment['weather_station_id'],
                'distance': assignment['distance'],
            },
        )

def update_feeder_loads():
    """Recompute ``load_typical`` for every feeder from its recent daily peaks.

    For each feeder with load data, the daily maximum of ``P`` is taken over
    the 180 days before the feeder's latest record. The most common daily
    maximum is saved as ``load_typical``; if ``statistics.mode`` raises
    ``StatisticsError`` the mean of the daily maxima (or 0 when there are
    none) is used instead. Feeders without load data are skipped.
    """
    for feeder in Feeder.objects.all():
        # Most recent load record of this feeder.
        latest_date = Load.objects.filter(feeder_id=feeder.id).aggregate(Max('time'))['time__max']
        if not latest_date:
            print(f"No data available for feeder {feeder.id}")
            continue

        # Window of 180 days ending at the latest record, both bounds tagged as UTC.
        start_date = (latest_date - timedelta(days=180)).replace(tzinfo=pytz.UTC)
        latest_date = latest_date.replace(tzinfo=pytz.UTC)

        # Daily maximum of P inside the window.
        loads_in_window = Load.objects.filter(feeder_id=feeder.id, time__range=(start_date, latest_date))
        per_day = loads_in_window.annotate(day=F('time__date')).values('day').annotate(daily_max=Max('P'))
        daily_max_values = per_day.values_list('daily_max', flat=True)

        try:
            typical_max = mode(daily_max_values)
        except StatisticsError:
            # No usable mode: fall back to the average of the daily maxima.
            if daily_max_values:
                typical_max = sum(daily_max_values) / len(daily_max_values)
            else:
                typical_max = 0

        feeder.load_typical = typical_max
        feeder.save()

        print(
            f"Saved typical max {typical_max} for feeder {feeder.id} considering data from {start_date} to {latest_date}")
def update_station_statistics(weather_station_id=None):
    """Refresh the data-coverage summary stored on weather stations.

    Args:
        weather_station_id: id of a single station to refresh; when None every
            station is refreshed.

    For each station this stores the number of ``Weather`` rows
    (``data_count``), a 0/1 ``has_data`` flag and the earliest/latest
    timestamps (``data_start``/``data_end``).  When the station has rows and
    at least one field has missing values, ``null_fields`` is set to a
    comma-separated list such as ``"temperature=0.125%,humidity=3.4%"``.
    Temperature readings outside 12..40 count as missing as well.
    """
    if weather_station_id is None:
        stations = WeatherStation.objects.all()
    else:
        stations = WeatherStation.objects.filter(id=weather_station_id)

    monitored = ('temperature', 'humidity', 'irradiance', 'windspeed',
                 'winddirection', 'rainfall', 'riverlevel')

    for station in stations:
        readings = Weather.objects.filter(weather_station_id=station.id)
        total = readings.count()
        span = readings.aggregate(earliest_date=Min('time'), latest_date=Max('time'))

        station.has_data = int(total > 0)
        station.data_count = total
        station.data_start = span['earliest_date']
        station.data_end = span['latest_date']

        if total > 0:
            summary = []
            for name in monitored:
                missing = readings.filter(**{f"{name}__isnull": True}).count()
                if name == 'temperature':
                    # Present but implausible temperatures are treated like gaps
                    missing += (readings.filter(temperature__isnull=False)
                                .exclude(temperature__range=(12, 40))
                                .count())
                if missing <= 0:
                    continue

                share = (missing / total) * 100
                # More decimals for small shares so they do not round to zero
                if share < 1:
                    spec = ".3f"
                elif share < 10:
                    spec = ".1f"
                else:
                    spec = ".0f"
                summary.append(f"{name}={format(share, spec)}%")

            # null_fields is only overwritten when there is something to report
            if summary:
                station.null_fields = ",".join(summary)

        station.save()


def summarize_weather_by_hour():
    """Replace the 'local' database's weather rows with hourly summaries.

    Rows of the ``weather`` table are grouped by station and hour.  Rainfall
    is summed; every other measurement is averaged.  All existing ``Weather``
    rows on the 'local' connection are then deleted and one row per
    (station, hour) is inserted in their place.

    The delete and the insert run in a single transaction so that a failed
    insert does not leave the table empty.
    """
    # Group the records by weather_station_id and hour, and calculate the summaries
    with connections['local'].cursor() as cursor:
        cursor.execute("""
                SELECT
                    weather_station_id,
                    strftime('%Y-%m-%d %H', time) AS day_hour,
                    AVG(temperature) AS average_temperature,
                    AVG(humidity) AS average_humidity,
                    AVG(irradiance) AS average_irradiance,
                    AVG(windspeed) AS average_windspeed,
                    AVG(winddirection) AS average_winddirection,
                    SUM(rainfall) AS sum_rainfall,
                    AVG(riverlevel) AS average_riverlevel,
                    COUNT(*) AS total_records
                FROM
                    weather
               GROUP BY
                   weather_station_id, day_hour
            """)
        rows = cursor.fetchall()

    # Build every replacement object before touching the table, so a malformed
    # row raises while the original data is still intact.
    # NOTE(review): AVG(winddirection) is an arithmetic mean of an angular
    # quantity; confirm that is acceptable for directions near the wrap-around.
    weather_objects_to_create = [
        Weather(
            weather_station_id=weather_station_id,
            # The summarised hour becomes the timestamp (minutes/seconds = 0)
            time=datetime.strptime(day_hour, '%Y-%m-%d %H').strftime('%Y-%m-%d %H:%M:%S'),
            temperature=avg_temp,
            humidity=avg_humidity,
            irradiance=avg_irradiance,
            windspeed=avg_windspeed,
            # Previously the averaged wind *speed* was stored here by mistake
            winddirection=avg_winddirection,
            rainfall=sum_rainfall,
            riverlevel=avg_riverlevel,
        )
        for (weather_station_id, day_hour, avg_temp, avg_humidity, avg_irradiance,
             avg_windspeed, avg_winddirection, sum_rainfall, avg_riverlevel,
             _total_records) in rows
    ]

    # Swap the raw rows for the summaries atomically on the 'local' database
    with transaction.atomic(using='local'):
        Weather.objects.using('local').all().delete()
        Weather.objects.using('local').bulk_create(weather_objects_to_create)


def get_combined_data(start_time=None, end_time=None):
    """Join feeder load records with the weather of each feeder's station.

    Args:
        start_time: optional lower bound (inclusive) on ``time``.
        end_time: optional upper bound (inclusive) on ``time``.

    Returns:
        A DataFrame with one row per load record, extended with the feeder's
        weather-station columns and the weather readings at the same ``time``
        (left join, so weather columns may be NaN).  Rows where ``P`` and all
        weather measurements are missing are dropped.  Returns None when no
        weather data or no load data is found.

    Weather is read from the ``weather_data`` table of ``temp_data.db`` first
    and from the ``Weather`` model when that lookup returns None.
    """
    feeders_with_stations = get_feeder_nearest_ws()
    feeder_ws_df = pd.DataFrame(feeders_with_stations)[
        ['id', 'FeederName', 'weather_station_id', 'weather_station_name']]
    feeder_ws_df.rename(columns={'id': 'feeder_id'}, inplace=True)

    # Build the cached-weather query with whichever time bounds were given.
    # NOTE(review): the bounds are interpolated into the SQL text; only pass
    # trusted values until load_SQL_data supports bound parameters.
    conditions = []
    if start_time:
        conditions.append(f'time >= "{start_time}"')
    if end_time:
        conditions.append(f'time <= "{end_time}"')
    query = 'SELECT * FROM weather_data'
    if conditions:
        query += ' WHERE ' + ' AND '.join(conditions)

    weather_data = load_SQL_data('temp_data.db', 'weather_data', query, ['time'])

    if weather_data is None:
        # No cached table available: fall back to the Weather model
        weather_query = Weather.objects
        if start_time:
            weather_query = weather_query.filter(time__gte=start_time)
        if end_time:
            weather_query = weather_query.filter(time__lte=end_time)
        weather_data = pd.DataFrame(weather_query.all().values())
        if weather_data.empty:
            print("weather_data is empty")
            return None

    # if start and end not provided only return loads in weather timespan
    load_query = Load.objects
    if not start_time or not end_time:
        weather_time_min = weather_data['time'].min()
        weather_time_max = weather_data['time'].max()
        load_query = load_query.filter(time__gte=weather_time_min, time__lte=weather_time_max)
    else:
        load_query = load_query.filter(time__gte=start_time, time__lte=end_time)
    load_data = pd.DataFrame(load_query.all().values())
    if load_data.empty:
        print("load_data is empty")
        return None

    # Ensure unique time indices
    weather_data.drop_duplicates(subset=['time', 'weather_station_id'], inplace=True)
    load_data.drop_duplicates(subset=['time', 'feeder_id'], inplace=True)

    # Attach each load record's feeder name and weather station
    combined_data = pd.merge(load_data, feeder_ws_df, on='feeder_id')

    # Attach the station's weather reading taken at the same time
    combined_data = pd.merge(combined_data, weather_data, on=['time', 'weather_station_id'], how='left')
    combined_data.dropna(
        subset=['temperature', 'humidity', 'irradiance', 'windspeed', 'winddirection', 'rainfall', 'riverlevel', 'P'],
        how='all', inplace=True)
    return combined_data

def clean_weather_data():
    """Clean all ``Weather`` rows and cache them in ``temp_data.db``.

    Steps:
      1. Blank out temperatures outside 12..40 and cap irradiance at 1150.
      2. Drop duplicate (time, station) rows and rows with no measurements.
      3. Reindex onto a full hourly grid for every station with ``valid=1``.
      4. For each measurement column, pass stations that still have gaps to
         ``fill_missing_values`` with method ``'spline'``.
      5. Clamp rainfall and irradiance to be non-negative, re-apply the
         irradiance cap, and drop rows that are still entirely empty.
      6. Save the result as table ``weather_data``.

    Returns None.  Does nothing (apart from a message) when there are no
    weather rows.
    """
    weather_query = Weather.objects
    weather_data = pd.DataFrame(weather_query.all().values())
    if weather_data.empty:
        # An empty frame has no columns, so the steps below would raise
        print("weather_data is empty")
        return

    # clear any values outside logical range
    # float('nan') is used so this block needs no numpy alias in scope
    missing = float('nan')
    weather_data.loc[weather_data['temperature'] < 12, 'temperature'] = missing
    weather_data.loc[weather_data['temperature'] > 40, 'temperature'] = missing
    weather_data.loc[weather_data['irradiance'] > 1150, 'irradiance'] = 1150

    # Ensure unique time indices
    weather_data.drop_duplicates(subset=['time', 'weather_station_id'], inplace=True)
    weather_data.dropna(
        subset=['temperature', 'humidity', 'irradiance', 'windspeed', 'winddirection', 'rainfall'],
        how='all', inplace=True)
    weather_data.set_index(['weather_station_id', 'time'], inplace=True)

    # Full (valid station x hour) grid spanning the observed time range
    all_stations = list(WeatherStation.objects.filter(valid=1).values_list('id', flat=True).distinct())
    all_dates = pd.date_range(start=weather_data.index.get_level_values('time').min(),
                              end=weather_data.index.get_level_values('time').max(), freq='H')
    new_index = pd.MultiIndex.from_product([all_stations, all_dates], names=['weather_station_id', 'time'])
    weather_data = weather_data.reindex(new_index)

    # Identify columns with missing data and fill them via fill_missing_values
    columns_to_impute = ['temperature', 'humidity', 'irradiance',
                         'windspeed', 'winddirection', 'rainfall']
    for column in columns_to_impute:
        if weather_data[column].isnull().any():
            # Only apply fill_missing_values function to groups that have NaN values for the specific column
            mask = weather_data.groupby(level='weather_station_id').apply(lambda g: g[column].isnull().any())
            valid_stations = mask[mask].index.tolist()

            # Boolean row selector for the stations that need imputation
            new_index = weather_data.index.get_level_values(0).isin(valid_stations)
            filtered_data = weather_data.loc[new_index]
            filtered_data = filtered_data.groupby(level='weather_station_id').apply(
                lambda group: fill_missing_values(group, column, 'spline'))
            # groupby.apply prepends the group key; remove that extra level so
            # the index lines up with weather_data again
            filtered_data.reset_index(level=0, inplace=True)
            weather_data.loc[filtered_data.index, column] = filtered_data[column]

    weather_data.loc[weather_data['rainfall'] < 0, 'rainfall'] = 0
    # Assign back instead of an in-place fillna on a column selection, which
    # is not guaranteed to modify the parent frame
    weather_data['rainfall'] = weather_data['rainfall'].fillna(0)
    weather_data.loc[weather_data['irradiance'] < 0, 'irradiance'] = 0
    weather_data.loc[weather_data['irradiance'] > 1150, 'irradiance'] = 1150
    # Reset index and drop NaN rows (those are the ones that were added during reindexing)
    weather_data.reset_index(inplace=True)
    weather_data.dropna(
        subset=['temperature', 'humidity', 'irradiance', 'windspeed', 'winddirection'],
        how='all', inplace=True)
    save_SQL_data(weather_data, 'temp_data.db', 'weather_data')

def get_feeder_model(feeder_id):
    """Return the trained model stored on a feeder, or None.

    Args:
        feeder_id: primary key of the ``Feeder`` to look up.

    Returns:
        The object unpickled from ``Feeder.trained_model``; None when the
        feeder does not exist or has no stored model.
    """
    try:
        stored = Feeder.objects.get(id=feeder_id).trained_model
    except Feeder.DoesNotExist:
        # Unknown feeder: treated the same as "no model"
        return None
    # NOTE(review): pickle executes code on load; the column must only ever
    # hold data written by this application.
    return pickle.loads(stored) if stored else None


def get_feeder_load_forecasts(start_time_utc, end_time_utc):
    """Predict P/Q for every feeder with a trained model and store the result.

    Args:
        start_time_utc: naive datetime, start of the forecast window (it is
            localized to UTC for the database queries).
        end_time_utc: naive datetime, end of the forecast window.

    For each feeder the weather forecast of its station is merged with the
    feeder's load records (starting 24 hours before the window), conditioned
    by ``condition_load_data`` and fed row by row to the feeder's model.
    Predictions are written to ``Load.P_forecast`` / ``Load.Q_forecast`` of
    the matching existing ``Load`` rows.

    A feeder lacking a model, forecast data or load data is reported and
    skipped; the remaining feeders are still processed.  Returns None.
    """
    utc = pytz.timezone('UTC')
    # Load history starts a day earlier than the forecast window
    load_window_start = utc.localize(start_time_utc - timedelta(hours=24))
    load_window_end = utc.localize(end_time_utc)

    feeders = Feeder.objects.only('id', 'weather_station_id', 'trained_features').all()

    for feeder in feeders:
        # Get the corresponding trained model
        model = get_feeder_model(feeder.id)
        if not model:
            print(f"No model found for feeder {feeder.id}.")
            continue

        forecasts = retrieve_weather_forecast(start_time_utc, end_time_utc, feeder.weather_station_id)
        weather_forecast = pd.DataFrame(list(forecasts))
        # Check with .empty: a DataFrame has no unambiguous truth value, and
        # an empty frame has no 'time' column to merge on.
        if weather_forecast.empty:
            print(f"No weather forecast data available for feeder {feeder.id}.")
            continue

        load_query = Load.objects.filter(time__gte=load_window_start,
                                         time__lte=load_window_end,
                                         feeder_id=feeder.id)
        load_data = pd.DataFrame(load_query.all().values())
        if load_data.empty:
            print("load_data is empty")
            continue

        merged_data = weather_forecast.merge(load_data, on='time', how='outer')
        merged_data = condition_load_data(merged_data, start_time_utc, end_time_utc)

        # Feature names the model was trained on, stored comma-separated
        model_features = [feature.strip() for feature in feeder.trained_features.split(',')]

        for _, row in merged_data.iterrows():
            # Extract features needed for prediction
            features = row[model_features].values

            # Predict the load using the model
            predicted_P, predicted_Q = model.predict([features])[0]

            # Only existing Load rows are updated; nothing is created here
            Load.objects.filter(time=row['time'], feeder_id=feeder.id).update(
                P_forecast=predicted_P, Q_forecast=predicted_Q)

def construct_load_json(start_time_utc, end_time_utc):
    """Serialize the weather forecast of every feeder's station to JSON.

    Args:
        start_time_utc: start of the window passed to
            ``retrieve_weather_forecast``.
        end_time_utc: end of the window passed to ``retrieve_weather_forecast``.

    Returns:
        A JSON string of the form
        ``{"weather": {feeder_id: {"YYYY-MM-DD HH:MM:SS": {...}}}}`` where the
        innermost object carries the forecast measurements for that timestamp.
    """
    measurement_names = (
        "temperature", "humidity", "wind_speed", "wind_direction",
        "irradiance", "rainfall", "riverlevel",
    )

    per_feeder = {}
    for feeder in Feeder.objects.only('id', 'weather_station_id'):
        rows = retrieve_weather_forecast(start_time_utc, end_time_utc, feeder.weather_station_id)
        # Timestamp string -> measurements for that forecast step
        per_feeder[feeder.id] = {
            row['time'].strftime('%Y-%m-%d %H:%M:%S'): {name: row[name] for name in measurement_names}
            for row in rows
        }

    return json.dumps({"weather": per_feeder})

def condition_load_data(load_data, start_time_utc, end_time_utc):
    """Add calendar, outage, time and lagged-load features to load data.

    Args:
        load_data: DataFrame with at least a datetime ``time`` column and the
            ``P`` and ``Q`` columns.
        start_time_utc: start of the requested window; rows before it are
            removed from the result after the lag features are computed.
        end_time_utc: end of the requested window (forwarded to the calendar
            and outage helpers).

    Returns:
        A DataFrame on a gap-free hourly grid with the added columns ``hour``,
        ``week_no``, ``P_hr_ago``/``Q_hr_ago``, ``P_day_ago``/``Q_day_ago``,
        ``P_avg_3h``/``Q_avg_3h`` and ``P_avg_1d``/``Q_avg_1d``, restricted to
        ``time >= start_time_utc``.
    """
    load_data = add_calendar_events(load_data, start_time_utc, end_time_utc)
    load_data = add_outages(load_data, start_time_utc, end_time_utc)

    # Ensure all hours are present in the DataFrame
    all_dates = pd.date_range(start=load_data['time'].min(), end=load_data['time'].max(), freq='H')
    load_data = load_data.set_index('time').reindex(all_dates).reset_index().rename(columns={'index': 'time'})

    # Derive hour and ISO week after the reindex so that the rows inserted for
    # missing hours get proper values instead of NaN
    load_data['hour'] = load_data['time'].dt.hour
    load_data['week_no'] = load_data['time'].dt.isocalendar().week

    # P and Q values from the previous hour and previous day (rows are hourly)
    load_data['P_hr_ago'] = load_data['P'].shift(1)
    load_data['Q_hr_ago'] = load_data['Q'].shift(1)
    load_data['P_day_ago'] = load_data['P'].shift(24)
    load_data['Q_day_ago'] = load_data['Q'].shift(24)

    # Trailing means over the previous 3 and 24 hours; shift(1) keeps the
    # current hour out of its own average
    load_data['P_avg_3h'] = load_data['P'].shift(1).rolling(window=3, min_periods=1).mean()
    load_data['Q_avg_3h'] = load_data['Q'].shift(1).rolling(window=3, min_periods=1).mean()
    load_data['P_avg_1d'] = load_data['P'].shift(1).rolling(window=24, min_periods=1).mean()
    load_data['Q_avg_1d'] = load_data['Q'].shift(1).rolling(window=24, min_periods=1).mean()

    # Remove records where time < start_time_utc.  A naive cutoff cannot be
    # compared with a timezone-aware column, so it is localized first.
    # NOTE(review): assumes a naive start_time_utc really is UTC, as its name
    # says; confirm against callers.
    cutoff = pd.Timestamp(start_time_utc)
    if load_data['time'].dt.tz is not None and cutoff.tzinfo is None:
        cutoff = cutoff.tz_localize('UTC')
    load_data = load_data[load_data['time'] >= cutoff]
    return load_data

def get_feeder_weather_forecast(start_time_utc, end_time_utc):
    """Refresh stored weather forecasts for all feeder stations and return them.

    Args:
        start_time_utc: start of the forecast window.
        end_time_utc: end of the forecast window.

    Returns:
        The JSON string produced by ``construct_weather_json`` for the window.

    For every weather station referenced by a feeder and every API configured
    in ``variables`` (group ``weather_forecast``, type ``API``), the forecast
    is downloaded and handed to the API-specific parser whenever
    ``need_forecast_refresh`` reports that stored data is missing or stale.
    APIs without a request builder, stations that no longer exist and failed
    requests are reported and skipped.
    """
    # Retrieve unique weather API names from the variables model
    api_names = variables.objects.filter(group='weather_forecast', type='API').values_list('name', flat=True).distinct()

    # Renew access tokens for APIs that use them
    token_refreshers = {
        "meteomatics": get_meteomatics_token,
    }
    for api_name in api_names:
        if not check_token_valid(api_name) and api_name in token_refreshers:
            token_refreshers[api_name]()

    # Get the timezone defined in your Django settings
    TZ = pytz.timezone(settings.TIME_ZONE)

    # Per-API functions that build the request (url, headers) ...
    request_builders = {
        "meteomatics": meteomatics_api_call,
        "tomorrow.io": tomorrow_io_api_call,
        "visualcrossing": visual_crossing_api_call,
    }
    # ... and that parse and store the JSON response
    response_parsers = {
        "meteomatics": parse_meteomatics_data,
        "tomorrow.io": parse_tomorrow_io_data,
        "visualcrossing": parse_visual_crossing_data
    }

    unique_weather_stations = Feeder.objects.all().values_list('weather_station_id', flat=True).distinct()
    for station in unique_weather_stations:
        weather_station = WeatherStation.objects.filter(id=station).first()
        if weather_station is None:
            # A feeder references a station id that does not exist (or none)
            print("Error: weather station not found", station)
            continue
        latitude = weather_station.Latitude
        longitude = weather_station.Longitude

        for api_name in api_names:
            build_request = request_builders.get(api_name)
            if build_request is None:
                # Without a builder there is no URL to call for this API
                print("Error: no API call defined for", api_name)
                continue

            need_refresh, new_start_time = need_forecast_refresh(start_time_utc, end_time_utc, station, api_name)
            if not need_refresh:
                continue

            # Retrieve API URL, parameter template and key for the current API
            api_endpoint = variables.objects.get(group='weather_forecast', type='API', name=api_name).value
            api_param = variables.objects.get(group='weather_forecast', type='parameters', name=api_name).value
            api_key = variables.objects.get(group='weather_forecast', type='KEY', name=api_name).value

            url, headers = build_request(api_endpoint, api_param, api_key, latitude, longitude,
                                         new_start_time, end_time_utc, TZ)

            # Make the API request; the timeout keeps one unresponsive
            # provider from blocking the whole refresh
            try:
                response = requests.get(url, headers=headers, timeout=30)
            except requests.RequestException as exc:
                print("Error:", exc, api_name)
                continue

            if response.status_code == 200:
                weather_data = response.json()
                if api_name in response_parsers:
                    response_parsers[api_name](weather_data, weather_station.id, 'weather_station_id',
                                               api_name, new_start_time, end_time_utc, TZ, WeatherForecast)
            else:
                print("Error:", response.text, api_name )

    results = construct_weather_json(start_time_utc, end_time_utc)
    return results

def construct_weather_json(start_time_utc, end_time_utc):
    """Build a JSON document with the weather forecast for every feeder.

    Args:
        start_time_utc: start of the window passed to
            ``retrieve_weather_forecast``.
        end_time_utc: end of the window passed to ``retrieve_weather_forecast``.

    Returns:
        JSON string ``{"weather": {feeder_id: {timestamp: measurements}}}``
        with timestamps formatted as ``YYYY-MM-DD HH:MM:SS``.
    """
    stamp_format = '%Y-%m-%d %H:%M:%S'
    by_feeder = {}

    for feeder in Feeder.objects.only('id', 'weather_station_id'):
        series = {}
        for entry in retrieve_weather_forecast(start_time_utc, end_time_utc, feeder.weather_station_id):
            stamp = entry['time'].strftime(stamp_format)
            series[stamp] = dict(
                temperature=entry['temperature'],
                humidity=entry['humidity'],
                wind_speed=entry['wind_speed'],
                wind_direction=entry['wind_direction'],
                irradiance=entry['irradiance'],
                rainfall=entry['rainfall'],
                riverlevel=entry['riverlevel'],
            )
        # Feeders that share a station get the same series under their own id
        by_feeder[feeder.id] = series

    return json.dumps({"weather": by_feeder})

def retrieve_weather_forecast(start_time_utc=None, end_time_utc=None, weather_station_id=None):
    """Return stored forecasts averaged per timestamp.

    Args:
        start_time_utc: optional naive datetime; localized to UTC and used as
            the inclusive lower bound on ``time``.
        end_time_utc: optional naive datetime; localized to UTC and used as
            the inclusive upper bound on ``time``.
        weather_station_id: optional station filter.

    Returns:
        A queryset of dicts ordered by ``time``.  Each dict holds ``time`` and
        the average of every measurement over all ``WeatherForecast`` rows
        sharing that timestamp.
    """
    utc = pytz.timezone('UTC')
    filters = {}

    if start_time_utc:
        filters['time__gte'] = utc.localize(start_time_utc)
    if end_time_utc:
        filters['time__lte'] = utc.localize(end_time_utc)
    # Compared against None on purpose: a station id of 0 must still filter
    if weather_station_id is not None:
        filters['weather_station_id'] = weather_station_id

    measurement_names = ('temperature', 'humidity', 'wind_speed', 'wind_direction',
                         'irradiance', 'rainfall', 'riverlevel')
    averages = {name: Avg(name) for name in measurement_names}

    # values('time') before annotate() groups the averages by timestamp
    return (WeatherForecast.objects
            .filter(**filters)
            .values('time')
            .annotate(**averages)
            .order_by('time'))
def need_forecast_refresh(start_date_utc, end_date_utc, id, name):
    """Decide whether stored forecasts for a station/API must be re-fetched.

    Args:
        start_date_utc: naive datetime, start of the requested window.
        end_date_utc: naive datetime, end of the requested window.
        id: weather station id to check.
        name: API name whose stored forecasts are checked.

    Returns:
        ``(refresh, start)``. ``refresh`` is True when data should be
        downloaded, and ``start`` is the time from which to request it.

        * If some hour of the window has no stored record, a refresh is
          needed starting at the first missing hour (or the current hour, if
          that is earlier and lies inside the window).
        * If every hour is stored and the window is entirely in the past, no
          refresh is needed.
        * Otherwise a refresh is needed only when the oldest ``updated``
          stamp among records at or after the current hour is at least one
          hour old.
    """
    utc = pytz.utc
    total_hours = int((end_date_utc - start_date_utc).total_seconds()) // 3600 + 1
    intervals_utc = [utc.localize(start_date_utc + timedelta(hours=i)) for i in range(total_hours)]
    # Current time truncated to the hour, naive (to compare with the
    # arguments) and aware (to compare with stored times)
    now_utc = datetime.utcnow().replace(minute=0, second=0, microsecond=0)
    now_aware = utc.localize(now_utc)

    # Ordered by 'updated' so the first future record is the stalest one
    records_with_updated = list(
        WeatherForecast.objects.filter(
            time__gte=utc.localize(start_date_utc),
            time__lte=utc.localize(end_date_utc),
            weather_station_id=id,
            api_name=name,
        ).order_by('updated').values('time', 'updated')
    )

    # Initialised under the same name that is read below; previously a
    # differently named variable was set, so the lookup could raise NameError
    last_updated_date = None
    post_now_records = [record for record in records_with_updated if record['time'] >= now_aware]
    if post_now_records:
        last_updated_date = post_now_records[0]['updated'].replace(tzinfo=None)

    # Convert the records to a set for faster lookup
    records_set = {record['time'] for record in records_with_updated}

    # Identify missing intervals
    missing_intervals = [dt.replace(tzinfo=None) for dt in intervals_utc if dt not in records_set]

    if missing_intervals:
        if start_date_utc <= now_utc <= end_date_utc:
            start_date_missing = min(missing_intervals[0], now_utc)
        else:
            start_date_missing = missing_intervals[0]
        return True, start_date_missing

    if end_date_utc < now_utc:
        # Fully covered window in the past: nothing can change any more
        return False, start_date_utc

    start_date_missing = min(start_date_utc, now_utc)
    if last_updated_date:
        next_hour = last_updated_date + timedelta(hours=1)
        return datetime.utcnow() >= next_hour, start_date_missing
    return False, start_date_missing

def tomorrow_io_api_call(endpoint, api_param, apikey, latitude, longitude, startTime, endTime, TZ):
    """Build the request URL and headers for a tomorrow.io forecast query.

    Args:
        endpoint: base URL of the API.
        api_param: parameter template passed to ``fill_api_parameters``.
        apikey: API key placed in the query parameters.
        latitude, longitude: location to query.
        startTime, endTime: requested time window.
        TZ: unused here; accepted so all API builders share one signature.

    Returns:
        ``(url, headers)`` ready to be passed to ``requests.get``.
    """
    requested_fields = ','.join((
        "precipitationIntensity",
        "precipitationType",
        "windSpeed",
        "windGust",
        "humidity",
        "windDirection",
        "temperature",
        "temperatureApparent",
        "cloudCover",
        "cloudBase",
        "cloudCeiling",
        "weatherCode",
    ))

    query_values = dict(
        latitude=latitude,
        longitude=longitude,
        fields=requested_fields,
        units="imperial",   # unit system: either metric or imperial
        timesteps="1h",     # e.g. "current", "1h" or "1d"
        start=startTime,
        end=endTime,
        timezone="UTC",     # standard IANA timezone name
        apikey=apikey,
    )

    url = endpoint + fill_api_parameters(api_param, query_values)
    return url, {'accept': 'application/json'}


def parse_tomorrow_io_data(api_response, location_id, location_name, api_name, startTime, endTime, TZ, ForecastModel):
    """Convert a tomorrow.io response into forecast records and store them.

    Args:
        api_response: decoded JSON; ``data.timelines[].intervals[]`` is read.
        location_id: id stored in each record's ``weather_station_id``.
        location_name: forwarded to ``add_records_to_model``.
        api_name: stored in each record's ``api_name``.
        startTime, endTime: unused here; accepted so all parsers share one
            signature.
        TZ: timezone the interval start times are converted to.
        ForecastModel: model class forwarded to ``add_records_to_model``.

    Every interval of every timeline becomes one record.  The records are
    handed to ``add_records_to_model`` once, after all timelines have been
    read, instead of once per timeline with the records accumulated so far.
    """
    timelines = api_response['data']['timelines']
    records_to_add_or_update = []

    for timeline in timelines:
        for interval in timeline['intervals']:
            # Interval start is given in UTC; convert to the configured zone
            forecast_time = datetime.strptime(interval['startTime'], "%Y-%m-%dT%H:%M:%SZ")
            forecast_time = forecast_time.replace(tzinfo=pytz.utc).astimezone(TZ)

            values = interval['values']
            defaults = {
                'temperature': unit_conversion('tomorrow_io', 'temperature', values['temperature']),
                'humidity': values['humidity'],
                'wind_speed': unit_conversion('tomorrow_io', 'windspeed', values['windSpeed']),
                'wind_direction': values['windDirection'],
                'rainfall': unit_conversion('tomorrow_io', 'rainfall', values['precipitationIntensity']),
                # No irradiance value is taken from this API
                'irradiance': None,
            }
            records_to_add_or_update.append({
                'time': forecast_time,
                'api_name': api_name,
                'weather_station_id': location_id,
                'defaults': defaults,
            })

    # As before, nothing is written when the response has no timelines at all
    if timelines:
        add_records_to_model(ForecastModel, records_to_add_or_update, api_name, location_name, location_id)


def meteomatics_api_call(endpoint, api_param, apikey, latitude, longitude, startTime, endTime, TZ):
    """Build the request URL and headers for a Meteomatics forecast query.

    Args:
        endpoint: base URL of the API.
        api_param: parameter template passed to ``fill_api_parameters``.
        apikey: key/token placed in the query parameters.
        latitude, longitude: location to query.
        startTime, endTime: requested time window.
        TZ: unused here; accepted so all API builders share one signature.

    Returns:
        ``(url, headers)``; the headers carry HTTP Basic credentials read from
        the ``variables`` model.
    """
    def _meteomatics_setting(kind):
        # Stored setting of the given type for the meteomatics API
        return variables.objects.get(group='weather_forecast', type=kind, name='meteomatics').value

    username = _meteomatics_setting('username')
    password = _meteomatics_setting('password')
    token = base64.b64encode(f'{username}:{password}'.encode()).decode()
    auth_headers = {'Authorization': f'Basic {token}'}

    requested_fields = ','.join([
        "t_2m:F",
        "wind_speed_10m:ms",
        "wind_dir_10m:d",
        "precip_1h:mm",
        "global_rad_5min:J",
        "relative_humidity_2m:p",
    ])
    query_values = {
        "latitude": latitude,
        "longitude": longitude,
        "fields": requested_fields,
        "model": "mix",
        "start": startTime,
        "end": endTime,
        "apikey": apikey,
    }

    return endpoint + fill_api_parameters(api_param, query_values), auth_headers


def parse_meteomatics_data(api_response, location_id, location_name, api_name, startTime, endTime, TZ, ForecastModel):
    """Convert a Meteomatics response into forecast records and store them.

    Args:
        api_response: decoded JSON; ``data`` is a list with one entry per
            parameter, each holding ``coordinates[].dates[]`` items with
            ``date`` and ``value``.
        location_id: id stored in each record's ``weather_station_id``.
        location_name: forwarded to ``add_records_to_model``.
        api_name: stored in each record's ``api_name``.
        startTime, endTime: unused here; accepted so all parsers share one
            signature.
        TZ: timezone the forecast dates are converted to.
        ForecastModel: model class forwarded to ``add_records_to_model``.

    The dates of the first parameter's first coordinate define the forecast
    steps.  A parameter with no value for a step contributes None.
    """
    data = api_response['data']

    # Index each parameter's values by date string once, instead of scanning
    # the date lists again for every forecast step.  setdefault keeps the
    # first value seen for a date, matching a first-match search.
    values_by_parameter = {}
    for param_entry in data:
        lookup = {}
        for coord_entry in param_entry['coordinates']:
            for item in coord_entry['dates']:
                lookup.setdefault(item['date'], item['value'])
        values_by_parameter[param_entry['parameter']] = lookup

    records_to_add_or_update = []
    for date_entry in data[0]['coordinates'][0]['dates']:
        date_str = date_entry['date']
        # Dates are given in UTC; convert to the configured timezone
        forecast_time = datetime.strptime(date_str, "%Y-%m-%dT%H:%M:%SZ")
        forecast_time = forecast_time.replace(tzinfo=pytz.utc).astimezone(TZ)

        readings = {parameter: lookup.get(date_str) for parameter, lookup in values_by_parameter.items()}

        radiation = readings['global_rad_5min:J']
        defaults = {
            'temperature': unit_conversion('meteomatics', 'temperature', readings['t_2m:F']),
            'humidity': readings['relative_humidity_2m:p'],
            'wind_speed': unit_conversion('meteomatics', 'windspeed', readings['wind_speed_10m:ms']),
            'wind_direction': readings['wind_dir_10m:d'],
            'rainfall': readings['precip_1h:mm'],
            # NOTE(review): 300 is presumably the number of seconds in the
            # 5-minute accumulation window; confirm.  A missing value stays
            # None instead of raising on the division.
            'irradiance': radiation / 300 if radiation is not None else None,
        }
        records_to_add_or_update.append({
            'time': forecast_time,
            'api_name': api_name,
            'weather_station_id': location_id,
            'defaults': defaults,
        })

    add_records_to_model(ForecastModel, records_to_add_or_update, api_name, location_name, location_id)


def get_meteomatics_token():
    """Fetch a new Meteomatics access token and store it with its expiry.

    The username and password are read from the ``variables`` model and sent
    as HTTP Basic credentials.  On success the token is saved as the
    ``KEY`` variable and an expiry time two hours from now as the
    ``exp_date`` variable (both in group ``weather_forecast``, name
    ``meteomatics``).  On failure a message is printed and nothing is stored.
    """
    username = variables.objects.get(group='weather_forecast', type='username', name='meteomatics').value
    password = variables.objects.get(group='weather_forecast', type='password', name='meteomatics').value

    # Encode the username and password in base64 for basic authorization
    credentials = base64.b64encode(f'{username}:{password}'.encode()).decode()
    headers = {
        'Authorization': f'Basic {credentials}'
    }

    # Request the token; the timeout prevents an unresponsive login server
    # from blocking the caller indefinitely
    try:
        response = requests.get('https://login.meteomatics.com/api/v1/token', headers=headers, timeout=30)
    except requests.RequestException as exc:
        print('Something went wrong. Request failed:', exc)
        return

    if response.status_code != 200:
        print('Something went wrong. Status Code:', response.status_code)
        return

    token = response.json()['access_token']
    # Expiry is recorded as naive local time, two hours after retrieval
    expires_at = datetime.now() + timedelta(hours=2)
    formatted_time = expires_at.strftime("%Y-%m-%d %H:%M:%S")

    # Store the token
    variables.objects.update_or_create(
        name='meteomatics',
        group='weather_forecast',
        type='KEY',
        defaults={'value': token}
    )

    # Store the expiration date
    variables.objects.update_or_create(
        name='meteomatics',
        group='weather_forecast',
        type='exp_date',
        defaults={'value': formatted_time}
    )


def check_token_valid(name):
    """Return whether the stored token of an API is still usable.

    Args:
        name: API name whose ``exp_date`` variable (group
            ``weather_forecast``) is checked.

    Returns:
        True when no expiry is stored for the API (nothing to refresh) or the
        stored expiry lies in the future; False when the expiry has passed or
        cannot be parsed.
    """
    expiration_date = variables.objects.filter(name=name, group='weather_forecast', type='exp_date').first()
    if not expiration_date:
        # APIs without a recorded expiry are treated as always valid
        return True

    try:
        expires_at = datetime.strptime(expiration_date.value, '%Y-%m-%d %H:%M:%S')
    except (TypeError, ValueError):
        # An unreadable expiry is treated as expired so a fresh token is fetched
        return False

    # get_meteomatics_token stores the expiry as naive local time, so it is
    # compared against the local clock rather than UTC
    return datetime.now() < expires_at


def visual_crossing_api_call(endpoint, api_param, apikey, latitude, longitude, startTime, endTime, TZ):
    """Build the request URL and headers for a Visual Crossing query.

    Both ``startTime`` and ``endTime`` are shifted forward by the whole-hour
    UTC offset that ``TZ`` reports for ``endTime`` before being substituted
    into the parameter template.

    Returns:
        A ``(url, headers)`` tuple: ``endpoint`` followed by the filled-in
        ``api_param`` template, and a headers dict carrying the API key under
        "token".
    """
    # Whole hours of UTC offset (fractional hours are truncated toward zero).
    offset_hours = int(TZ.utcoffset(endTime).total_seconds() / 3600)
    shift = timedelta(hours=offset_hours)

    query = fill_api_parameters(api_param, {
        "latitude": latitude,
        "longitude": longitude,
        "start": startTime + shift,
        "end": endTime + shift,
        "APIKey": apikey,
    })

    return endpoint + query, {"token": apikey}


def parse_visual_crossing_data(api_response, location_id, location_name, api_name, startTime, endTime, TZ,
                               ForecastModel):
    """Convert a Visual Crossing response into forecast rows and store them.

    Each hourly entry under ``api_response['days'][i]['hours']`` is combined
    with its day's date, shifted back by the whole-hour UTC offset that ``TZ``
    reports for ``endTime``, and kept only when it falls inside
    ``[startTime, endTime]``. Kept entries are unit-converted and handed to
    ``add_records_to_model``.

    Args:
        api_response: Decoded JSON with a 'days' list; each day has
            'datetime' (``%Y-%m-%d``) and an 'hours' list whose items carry
            'datetime' (``%H:%M:%S``), 'temp', 'humidity', 'windspeed',
            'winddir', 'precip' and 'solarradiation'.
        location_id: Identifier of the location the forecast belongs to.
        location_name: Model field name under which ``location_id`` is stored.
        api_name: Name of the forecast API, stored on every row.
        startTime: Inclusive lower bound (naive datetime).
        endTime: Inclusive upper bound (naive datetime).
        TZ: Timezone object providing ``utcoffset`` and used as the target
            of ``astimezone``.
        ForecastModel: Django model class receiving the rows.
    """
    # The offset depends only on TZ and endTime, so compute it once.
    offset = timedelta(hours=int(TZ.utcoffset(endTime).total_seconds() / 3600))

    records_to_add_or_update = []
    for day in api_response['days']:
        day_start = datetime.strptime(day['datetime'], "%Y-%m-%d")
        for data_point in day['hours']:
            time_of_day = datetime.strptime(data_point['datetime'], "%H:%M:%S").time()
            timestamp = datetime.combine(day_start, time_of_day) - offset
            if not (startTime <= timestamp <= endTime):
                continue

            records_to_add_or_update.append({
                # Tag the shifted time as UTC, then express it in TZ.
                'time': timestamp.replace(tzinfo=pytz.utc).astimezone(TZ),
                'api_name': api_name,
                # add_records_to_model reads the location via
                # record.get(location_name), so key it by that field name.
                location_name: location_id,
                'defaults': {
                    'temperature': unit_conversion('visual_crossing', 'temperature', data_point['temp']),
                    'humidity': data_point['humidity'],
                    'wind_speed': unit_conversion('visual_crossing', 'windspeed', data_point['windspeed']),
                    'wind_direction': data_point['winddir'],
                    'rainfall': unit_conversion('visual_crossing', 'rainfall', data_point['precip']),
                    'irradiance': data_point['solarradiation'],
                },
            })

    if not records_to_add_or_update:
        # Nothing fell inside the requested window; there is nothing to store.
        return

    add_records_to_model(ForecastModel, records_to_add_or_update, api_name, location_name, location_id)


def add_record_to_model(WeatherModel, date, api_name, location_name, location_id, defaults):
    """Insert or update a single weather row.

    A row is matched on ``time`` and the ``location_name`` field, plus
    ``api_name`` for every API other than 'nms'. When a match exists all
    field values are written onto it and it is saved; otherwise a new
    ``WeatherModel`` instance is created and saved.

    Args:
        WeatherModel: Django model class to write to.
        date: Timestamp stored in the ``time`` field.
        api_name: Source API name; 'nms' rows carry no ``api_name`` field.
        location_name: Model field name identifying the location.
        location_id: Value for that location field.
        defaults: Mapping of measurement field names to values.
    """
    location_filter = {location_name: location_id}

    if api_name == 'nms':
        lookup = Q(time=date, **location_filter)
        attrs = {'time': date, 'weather_station_id': location_id, **defaults}
    else:
        lookup = Q(time=date, api_name=api_name, **location_filter)
        attrs = {'time': date, 'api_name': api_name, **defaults, location_name: location_id}

    match = WeatherModel.objects.filter(lookup).first()

    if not match:
        # No row yet for this time/location: create one.
        WeatherModel(**attrs).save()
        return

    # Overwrite the stored row with the incoming values.
    for attr_name, attr_value in attrs.items():
        setattr(match, attr_name, attr_value)
    match.save()


def add_records_to_model(WeatherModel, records, api_name, location_name, location_id):
    """Bulk insert-or-update weather rows.

    Existing rows are fetched in one query, matched on ``time`` and the
    ``location_name`` field (plus ``api_name`` unless the API is 'nms'),
    then updated with ``bulk_update``; unmatched records are inserted with
    ``bulk_create``.

    Args:
        WeatherModel: Django model class to write to.
        records: Iterable of dicts with 'time', optionally the
            ``location_name`` key, and 'defaults' (field name -> value).
        api_name: Source API name; 'nms' rows carry no ``api_name`` field.
        location_name: Model field name identifying the location.
        location_id: Location used for the lookup query, and as the fallback
            for records that do not carry their own ``location_name`` key.
    """
    if not records:
        # reduce() over an empty condition list would raise TypeError.
        return

    is_nms = api_name == 'nms'

    # One Q object per record; OR-ed together into a single query below.
    condition_list = []
    for record in records:
        lookup = {'time': record.get('time'), location_name: location_id}
        if not is_nms:
            lookup['api_name'] = api_name
        condition_list.append(Q(**lookup))

    existing_records = WeatherModel.objects.filter(reduce(operator.or_, condition_list))

    # Index the fetched rows by the same key shape used for incoming records.
    existing_records_dict = {}
    for er in existing_records:
        stored_location = getattr(er, location_name)
        key = (er.time, stored_location) if is_nms else (er.time, api_name, stored_location)
        existing_records_dict[key] = er

    updated_records = []
    new_records = []
    # Union of every field actually modified on an existing row (dict keeps
    # insertion order so the resulting field list is deterministic).
    updated_fields = {}

    for record in records:
        timestamp = record.get('time')
        # Fall back to the function argument when the record has no location.
        record_location_id = record.get(location_name, location_id)
        defaults = record.get('defaults', {})

        key = (timestamp, record_location_id) if is_nms else (timestamp, api_name, record_location_id)
        existing_record = existing_records_dict.get(key)

        if existing_record is not None:
            for field, value in defaults.items():
                setattr(existing_record, field, value)
                updated_fields[field] = None
            updated_records.append(existing_record)
            continue

        if is_nms:
            field_values = {
                'time': timestamp,
                'weather_station_id': record_location_id,
                **defaults,
            }
        else:
            field_values = {
                'time': timestamp,
                'api_name': api_name,
                **defaults,
            }
            field_values[location_name] = record_location_id
        new_records.append(WeatherModel(**field_values))

    # bulk_update requires a non-empty field list.
    if updated_records and updated_fields:
        WeatherModel.objects.bulk_update(updated_records, list(updated_fields))

    if new_records:
        WeatherModel.objects.bulk_create(new_records)

def get_historical_weather(start, end):
    """Download historical NMS observations for every relevant weather station.

    For each distinct ``weather_station_id`` returned by
    ``calc_nearest_weather_station()``, data is requested when the station's
    ``data_end`` lies in the past or its ``data_count`` is below 20000. The
    requested range is walked backwards from ``end`` in windows of at most
    seven days, one HTTP request per weather parameter per window, and each
    successful response is passed to ``parse_nms_data``. Station statistics
    are refreshed after each processed station and once more at the end.

    Args:
        start: Start of the requested range; converted from
            ``settings.TIME_ZONE`` to UTC with ``convert_time_naive``.
            Replaced by the station's ``data_end`` when that is set.
        end: End of the requested range, converted the same way.
    """
    api_endpoint = variables.objects.get(group='weather_historical', type='API', name='nms').value
    api_param = variables.objects.get(group='weather_historical', type='parameters', name='nms').value
    TZ = pytz.timezone(settings.TIME_ZONE)

    # Size of each request window (7 days)
    increment = timedelta(days=7)
    # Value sent as the "field" request parameter for each weather parameter
    searchvalue2 = {
        'temperature': 10,  # deg_C
        'rainfall': 0,  # mm
        'windspeed': 51,  # knots
        'winddirection': 56,  # degrees
        'riverlevel': 4013,
        'irradiance': 72,  # Watts/m^2
        'humidity': 4007,  # %
    }
    result_df = calc_nearest_weather_station()
    for weather_station_id in result_df['weather_station_id'].unique():
        weather_station = WeatherStation.objects.filter(id=weather_station_id).first()
        if weather_station is None:
            # The id no longer resolves to a station; skip it instead of
            # failing on attribute access and aborting the remaining stations.
            print("Weather station not found:", weather_station_id)
            continue

        station_latest_date = weather_station.data_end
        new_records_present = bool(station_latest_date) and (
            station_latest_date < datetime.utcnow().replace(tzinfo=pytz.utc)
        )

        if not (new_records_present or weather_station.data_count < 20000):
            continue

        end_date = convert_time_naive(end, settings.TIME_ZONE, 'UTC')
        start_date = convert_time_naive(start, settings.TIME_ZONE, 'UTC')
        if station_latest_date:
            # Resume from the newest stored observation (made naive so it
            # compares with the naive converted bounds).
            start_date = station_latest_date.replace(tzinfo=None)

        while end_date >= start_date:
            # Window for this iteration, clamped to the overall start
            endTime = end_date
            startTime = max(endTime - increment, start_date)
            if startTime == endTime:
                break
            for key, value in searchvalue2.items():
                url, headers = nms_api_call(api_endpoint, api_param, weather_station.iid, startTime, endTime, TZ, value)
                response = requests.get(url, headers=headers)
                if response.status_code == 200:
                    weather_data = response.json()
                    parse_nms_data(weather_data, startTime, endTime, TZ, weather_station.id, key)
                else:
                    print("Error:", response.text)
                    print(url)

            # Move back 7 days for the next iteration
            end_date -= increment
        update_station_statistics(weather_station_id)

    update_station_statistics()

def nms_api_call(endpoint, api_param, weather_station_iid, startTimeUTC, endTimeUTC, TZ, value):
    """Build the request URL and headers for an NMS observation query.

    The end time's minute is set to 55 before it is substituted into the
    parameter template. ``TZ`` is accepted but not used here.

    Returns:
        A ``(url, headers)`` tuple: ``endpoint`` followed by the filled-in
        ``api_param`` template, and a headers dict with a fixed User-Agent.
    """
    query = fill_api_parameters(api_param, {
        "weather_station_iid": weather_station_iid,
        "field": value,
        "start": startTimeUTC,
        "end": endTimeUTC.replace(minute=55),
    })

    return endpoint + query, {"User-Agent": "Chrome/116.0.5845.180"}


# Maps (api_name, field_name) to the function converting that API's value
# into the unit stored in the database.
_UNIT_CONVERSIONS = {
    ("visual_crossing", "temperature"): lambda x: (x - 32) * 5 / 9,  # convert from F to C
    ("visual_crossing", "windspeed"): lambda x: x * 0.868976,  # convert from mph to knots
    ("visual_crossing", "rainfall"): lambda x: x * 25.4,  # convert from in to mm
    ("meteomatics", "temperature"): lambda x: (x - 32) * 5 / 9,  # convert from F to C
    ("meteomatics", "windspeed"): lambda x: x * 1.94384,  # convert from m/s to knots
    ("tomorrow_io", "temperature"): lambda x: (x - 32) * 5 / 9,  # convert from F to C
    ("tomorrow_io", "windspeed"): lambda x: x * 0.868976,  # convert from mph to knots
    ("tomorrow_io", "rainfall"): lambda x: x * 25.4,  # convert from in/hr to mm/hr
}


def unit_conversion(api_name, field_name, value):
    """Convert ``value`` from an API's unit to the stored unit.

    Args:
        api_name: Name of the source API (e.g. 'visual_crossing').
        field_name: Name of the weather field (e.g. 'temperature').
        value: Numeric reading, or None / NaN for a missing reading.

    Returns:
        None when ``value`` is None or NaN; the converted value when a
        conversion is registered for ``(api_name, field_name)``; otherwise
        ``value`` unchanged.
    """
    # NaN is the only value that is not equal to itself.
    if value is None or value != value:
        return None

    operation = _UNIT_CONVERSIONS.get((api_name, field_name))
    if operation is None:
        # No mapping registered: the value is already in the stored unit.
        return value
    return operation(value)


def parse_nms_data(api_response, startTime, endTime, TZ, weather_station_id, key):
    """Aggregate an NMS response to hourly values and store them.

    The entries under ``api_response["results"]`` are indexed by their
    'date', the 'station' and 'variable' columns are dropped, and the
    remainder is resampled to hourly buckets using the aggregation
    registered for ``key``. The first remaining column of each bucket is
    stored in the ``Weather`` model under the field named ``key``.

    Args:
        api_response: Decoded JSON containing a "results" list.
        startTime: Not used by this function.
        endTime: Not used by this function.
        TZ: Timezone the hourly timestamps are converted to.
        weather_station_id: Station the readings belong to.
        key: Weather parameter name; selects the aggregation and is the
            model field that receives the value.
    """
    results = api_response["results"]
    if not results:
        # An empty result set yields a frame without a "date" column, and
        # there would be nothing to store anyway.
        return

    df = pd.DataFrame(results)

    # Index by timestamp so the frame can be resampled
    df["date"] = pd.to_datetime(df["date"])
    df.set_index("date", inplace=True)
    df = df.drop(['station', 'variable'], axis=1)

    # How each parameter is reduced to one value per hour
    aggregation_functions = {
        'maxtemp': 'max',
        'temperature': 'mean',
        'mintemp': 'min',
        'rainfall': 'sum',
        'pressure': 'mean',  # hPa
        'windspeed': 'mean',  # knots
        'winddirection': 'mean',  # degrees
        'riverlevel': 'sum',
        'irradiance': 'mean',  # Watts/m^2
        'humidity': 'mean',  # %
    }
    hourly_data = df.resample("H").agg(aggregation_functions[key])

    records_to_add_or_update = []
    for timestamp, record in hourly_data.iterrows():
        # Tag the bucket time as UTC, then express it in TZ
        local_time = timestamp.replace(tzinfo=pytz.utc).astimezone(TZ)
        records_to_add_or_update.append({
            'time': local_time,
            'weather_station_id': weather_station_id,
            'defaults': {key: unit_conversion('nms', key, record.iloc[0])},
        })

    add_records_to_model(Weather, records_to_add_or_update, 'nms', 'weather_station_id', weather_station_id)