import pandas as pd
import os
import numpy as np
import glob


def label_images(dataset_path, state_labels_file="state_labels_DroneState.csv"):
    """Pair every image with the state sample closest to it in time.

    Images are listed from ``<dataset_path>/images`` and must be named
    ``<integer timestamp>.<ext>``; the stem is divided by 1e3 so it can be
    compared against the state CSV's ``"timestamp [ms]"`` column (the output
    column name suggests the stem is in microseconds). For each image the
    state row with the smallest absolute time difference is selected (the
    first such row on ties, per ``np.argmin``). The result is written to
    ``<dataset_path>/labeled_images.csv``. Its first column is the image file
    name (``"timestamp_us.jpeg"``), followed by all columns of the chosen
    state rows.

    Args:
        dataset_path: Directory holding the state CSV and an ``images/`` folder.
        state_labels_file: Name of the state CSV inside ``dataset_path``.
            Defaults to the previously hard-coded file name.
    """
    state_labels = pd.read_csv(os.path.join(dataset_path, state_labels_file))

    def key(image_name):
        # Integer stem of the file name, e.g. "1500.jpeg" -> 1500.
        return int(image_name.split('.')[0])

    image_names = sorted(os.listdir(os.path.join(dataset_path, "images")), key=key)
    # Scale image stems by 1e3 to express them in the state file's ms unit.
    image_timestamps_ms = [key(img_name) / 1.0e3 for img_name in image_names]

    # Extract the state timestamps once rather than once per image.
    state_timestamps_ms = state_labels["timestamp [ms]"].values
    minimum_distance_idxs = [
        np.argmin(np.abs(state_timestamps_ms - image_timestamp))
        for image_timestamp in image_timestamps_ms
    ]

    # drop=True so the row positions line up with img_names_df for the
    # column-wise concat below.
    close_state_samples = state_labels.iloc[minimum_distance_idxs, :].reset_index(drop=True)
    img_names_df = pd.DataFrame(image_names, columns=["timestamp_us.jpeg"])
    labeled_images = pd.concat([img_names_df, close_state_samples], axis=1)
    labeled_images.to_csv(os.path.join(dataset_path, "labeled_images.csv"), index=False)


def label_images_gap_ticks_old(dataset_path):
    """Match each image to the nearest-tick row of every state-label file.

    Older variant (per the ``_old`` suffix) of ``label_images_gap_ticks``.
    Images are listed from ``<dataset_path>/images`` and named
    ``<integer ticks>.<ext>``. Every ``state_labels*.csv`` file in
    ``dataset_path`` is read, and a ``"ticks"`` column is renamed to
    ``"timeTicks"``. For each image, the file's row with the nearest
    ``timeTicks`` is selected (the first such row on ties).

    The result is written to ``<dataset_path>/labeled_images.csv``:

    - The first column is ``"timeTicks.jpeg"``, the image file name.
    - Next come, per state file in sorted path order, all of that file's
      columns, with ``timeTicks`` renamed to ``"<conf>_TimeTicks"``.
    - ``<conf>`` is the file-name text after the last ``_``, without the
      extension.

    Args:
        dataset_path: Directory holding ``state_labels*.csv`` and ``images/``.
    """
    # glob returns paths in file-system order; sort so the output column
    # order is reproducible across machines and runs.
    state_labels_files = sorted(glob.glob(os.path.join(dataset_path, "state_labels*.csv")))

    def key(image_name):
        # Integer stem of the file name, e.g. "12345.jpeg" -> 12345.
        return int(image_name.split('.')[0])

    image_names = sorted(os.listdir(os.path.join(dataset_path, "images")), key=key)
    image_ticks = [key(img_name) for img_name in image_names]

    # Gather all frames and concatenate once, instead of re-copying the
    # growing table on every iteration.
    frames = [pd.DataFrame(image_names, columns=["timeTicks.jpeg"])]
    for state_labels_file in state_labels_files:
        state_labels = pd.read_csv(state_labels_file)
        if "ticks" in state_labels.columns.values:
            # Some state files call the tick column "ticks"; normalise it.
            state_labels = state_labels.rename(columns={"ticks": "timeTicks"})
            print(state_labels.columns)
        state_ticks = state_labels["timeTicks"].values
        minimum_distance_idxs = [
            np.argmin(np.abs(state_ticks - image_tick)) for image_tick in image_ticks
        ]
        # drop=True keeps row positions aligned with the image-name frame.
        close_state_samples = state_labels.iloc[minimum_distance_idxs, :].reset_index(drop=True)
        confName = state_labels_file.split("_")[-1].split('.')[0]
        close_state_samples = close_state_samples.rename(columns={"timeTicks": confName + "_TimeTicks"})
        frames.append(close_state_samples)
    labeled_images = pd.concat(frames, axis=1)
    labeled_images.to_csv(os.path.join(dataset_path, "labeled_images.csv"), index=False)


def label_images_gap_ticks(dataset_path):
    """Match each image to the nearest value-change of every state column.

    Images are listed from ``<dataset_path>/images`` and named
    ``<integer ticks>.<ext>``. Every ``state_labels*.csv`` file in
    ``dataset_path`` is read, and a ``"ticks"`` column is renamed to
    ``"timeTicks"``.

    For each non-``timeTicks`` column, only the first row and the rows where
    the column's value differs from the previous row are kept. Each image is
    then paired with the kept row whose ``timeTicks`` is nearest (the first
    such row on ties).

    The result is written to ``<dataset_path>/labeled_images.csv``:

    - The first column is ``"timeTicks.jpeg"``, the image file name.
    - Each (state file, column) pair, in sorted file order and then CSV
      column order, appends three columns:

      - ``"index"``: the matched row's index in the state file.
      - ``"<conf>_<col>_TimeTicks"``: the matched tick.
      - ``<col>``: the matched value.

    - ``<conf>`` is the file-name text after the last ``_``, without the
      extension.

    Args:
        dataset_path: Directory holding ``state_labels*.csv`` and ``images/``.
    """
    # glob returns paths in file-system order; sort so the output column
    # order is reproducible across machines and runs.
    state_labels_files = sorted(glob.glob(os.path.join(dataset_path, "state_labels*.csv")))

    def key(image_name):
        # Integer stem of the file name, e.g. "12345.jpeg" -> 12345.
        return int(image_name.split('.')[0])

    image_names = sorted(os.listdir(os.path.join(dataset_path, "images")), key=key)
    image_ticks = [key(img_name) for img_name in image_names]

    # Gather all frames and concatenate once, instead of re-copying the
    # growing table for every (file, column) pair.
    frames = [pd.DataFrame(image_names, columns=["timeTicks.jpeg"])]
    for state_labels_file in state_labels_files:
        state_labels = pd.read_csv(state_labels_file)
        if "ticks" in state_labels.columns.values:
            # Some state files call the tick column "ticks"; normalise it.
            state_labels = state_labels.rename(columns={"ticks": "timeTicks"})
            print(state_labels.columns)
        # Same for every column of this file, so compute it once.
        confName = state_labels_file.split("_")[-1].split('.')[0]
        for col in state_labels.columns.values:
            if col == "timeTicks":
                continue
            col_data = state_labels[["timeTicks", col]]
            # Row 0 plus every row whose value differs from its predecessor,
            # i.e. the tick at which each new value of `col` first appears.
            change_idxs = np.concatenate(([0], np.where(np.abs(np.diff(col_data[col])) > 0)[0] + 1))
            change_rows = col_data.iloc[change_idxs]
            change_ticks = change_rows["timeTicks"].values
            nearest_idxs = [
                np.argmin(np.abs(change_ticks - image_tick)) for image_tick in image_ticks
            ]
            # reset_index without drop=True: the state file's original row
            # index is kept as an "index" column (one per state column),
            # and the result gets a fresh 0..n-1 index for the concat.
            close_state_samples = change_rows.iloc[nearest_idxs, :].reset_index(level=0, inplace=False)
            close_state_samples = close_state_samples.rename(
                columns={"timeTicks": "{}_{}_TimeTicks".format(confName, col)})
            frames.append(close_state_samples)
    labeled_images = pd.concat(frames, axis=1)
    labeled_images.to_csv(os.path.join(dataset_path, "labeled_images.csv"), index=False)


if __name__ == "__main__":
    # Local import: only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description="Write labeled_images.csv for a dataset by matching each "
                    "image to the nearest value change of every state-label column.")
    # Optional positional argument; when omitted, the previously hard-coded
    # path is used, so existing invocations behave as before.
    parser.add_argument(
        "dataset_path", nargs="?", default=r"../dataset/CiakTest",
        help="dataset directory containing images/ and state_labels*.csv; "
             "relative paths resolve against the current working directory "
             "(default: %(default)s)")
    args = parser.parse_args()
    label_images_gap_ticks(args.dataset_path)
