from PIL import Image, ImageFont, ImageDraw
import glob
import os
import pandas as pd
import numpy as np
from dataset_post_processing import label_images_gap_ticks


def stack_images_of_dataset(dataset_folder, n_rows, n_columns, down_sample_max=np.inf, start=0, image_name=None):
    """Tile the images of a dataset folder into an ``n_rows`` x ``n_columns`` grid.

    Images are read from ``<dataset_folder>/images/*.jpeg`` and ordered by the
    integer in their file name (``1234.jpeg`` -> 1234). Per-image captions are
    built from ``<dataset_folder>/labeled_images.csv`` when it has a
    ``range_front[mm]`` column, or a ``range.front`` column (optionally with
    ``mRange.rangeStatusFront``). If the CSV does not exist it is generated
    first with ``label_images_gap_ticks``.

    NOTE(review): captions are matched to images purely by position, so the CSV
    rows are assumed to line up one-to-one with the sorted image files.

    If fewer than ``n_rows * n_columns`` images remain after ``start``, both
    grid dimensions are reduced by the same amount until the grid fits.

    Args:
        dataset_folder: Dataset root directory.
        n_rows: Requested number of grid rows.
        n_columns: Requested number of grid columns.
        down_sample_max: Upper bound on the sampling stride used by ``stack_images``.
        start: Index (in sorted order) of the first image to use.
        image_name: Optional output path forwarded to ``stack_images``.

    Returns:
        The value returned by ``stack_images``, or ``None`` (after printing
        "Not enough images") when no non-empty grid fits.
    """
    label_file = os.path.join(dataset_folder, 'labeled_images.csv')
    if not os.path.isfile(label_file):
        label_images_gap_ticks(dataset_folder)
    label_data = pd.read_csv(label_file)

    def img_ticks(img_file_path):
        # basename() instead of split('/') so this also works with Windows separators.
        return int(os.path.basename(img_file_path).split('.')[0])

    img_files = sorted(glob.glob(os.path.join(dataset_folder, 'images/*.jpeg')), key=img_ticks)
    nb_of_imgs = len(img_files)

    # Exact column lookups: the original substring test could match a column
    # that the subsequent exact indexing then failed to find.
    img_texts = None
    columns = label_data.columns
    if "range_front[mm]" in columns:
        # This log format carries no status column, so caption the range only.
        img_texts = ['{}mm'.format(v) for v in label_data["range_front[mm]"].values]
    elif "range.front" in columns:
        r_front = label_data["range.front"].values
        if "mRange.rangeStatusFront" in columns:
            r_status = label_data["mRange.rangeStatusFront"].values
            img_texts = ['{}mm, status={}'.format(f, s) for f, s in zip(r_front, r_status)]
        else:
            img_texts = ['{}mm'.format(f) for f in r_front]

    available = nb_of_imgs - start
    if n_rows * n_columns > available:
        # Shrink both dimensions together, but never past zero: once a
        # dimension would go negative the product turns positive again and
        # the original loop never terminated (e.g. when start > nb_of_imgs).
        shrink = 1
        while (n_rows - shrink > 0 and n_columns - shrink > 0
               and (n_rows - shrink) * (n_columns - shrink) > available):
            shrink += 1
        n_rows -= shrink
        n_columns -= shrink
        if n_rows <= 0 or n_columns <= 0:
            print("Not enough images")
            return None

    return stack_images(n_rows, n_columns, img_files, down_sample_max, start,
                        label_strings=img_texts, image_name=image_name)


def _load_label_font(size=16, path='/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf'):
    """Return the caption font, falling back to Pillow's built-in font if ``path`` is unavailable."""
    try:
        return ImageFont.truetype(path, size)
    except OSError:
        # The DejaVu path only exists on some Linux distributions; don't crash elsewhere.
        return ImageFont.load_default()


def _render_tile(img_file, index, label, font):
    """Open ``img_file`` as RGB, draw ``index`` (and ``label`` if given) in red, and return it as an array."""
    # Close the file handle right away; convert() returns an independent, loaded copy.
    with Image.open(img_file) as src:
        img = src.convert('RGB')
    draw = ImageDraw.Draw(img)
    draw.text((20, 20), "{}".format(index), fill=(255, 0, 0), font=font)
    if label is not None:
        draw.text((20, img.size[1] - 40), str(label), fill=(255, 0, 0), font=font)
    return np.array(img)


def stack_images(n_rows, n_columns, img_files, down_sample_max, start, label_strings=None, image_name=None,
                 show=True):
    """Sample images at a fixed stride and tile them row-major into one image.

    The stride is ``(len(img_files) - start) // (n_rows * n_columns)``, capped
    at ``down_sample_max``. Images ``img_files[start::stride]`` fill the grid
    left-to-right, top-to-bottom. Each tile is captioned with its grid index and,
    when ``label_strings`` is given, with the matching label (sampled at the
    same offset and stride). All images must share the same size for the
    numpy stacking to succeed.

    Args:
        n_rows: Number of grid rows (must be > 0).
        n_columns: Number of grid columns (must be > 0).
        img_files: Image file paths in display order.
        down_sample_max: Upper bound on the sampling stride (may be ``np.inf``).
        start: Index of the first image to consider.
        label_strings: Optional per-image captions aligned with ``img_files``.
        image_name: If given, the stacked image is saved to this path.
        show: If true (the default), open the result with ``Image.show()``.

    Returns:
        The stacked ``PIL.Image.Image``.

    Raises:
        ValueError: If the grid is empty, or there are too few images after
            ``start`` to fill it (the stride would be < 1).
    """
    if n_rows <= 0 or n_columns <= 0:
        raise ValueError("n_rows and n_columns must be positive, got {}x{}".format(n_rows, n_columns))
    n_tiles = n_rows * n_columns
    # int() so a float cap such as 1.0 still yields a valid slice step.
    down_sample = int(min((len(img_files) - start) // n_tiles, down_sample_max))
    print("Down sample rate: {}".format(down_sample))
    if down_sample < 1:
        raise ValueError("Not enough images after start={} to fill a {}x{} grid".format(start, n_rows, n_columns))

    img_files = img_files[start::down_sample]
    if label_strings is not None:
        label_strings = label_strings[start::down_sample]

    font = _load_label_font()
    rows = []
    for r in range(n_rows):
        tiles = []
        for c in range(n_columns):
            idx = r * n_columns + c
            label = None if label_strings is None else label_strings[idx]
            tiles.append(_render_tile(img_files[idx], idx, label, font))
        # Concatenate once per row/grid instead of re-stacking on every tile.
        rows.append(np.hstack(tiles))

    image_stacked = Image.fromarray(np.vstack(rows))
    if image_name is not None:
        image_stacked.save(image_name)
    if show:
        image_stacked.show()
    return image_stacked

if __name__ == "__main__":
    # Example run: a 1x3 strip starting at image 63, taking consecutive frames (stride capped at 1).
    stack_images_of_dataset(
        "../dataset/CiakTest",
        n_rows=1,
        n_columns=3,
        start=63,
        down_sample_max=1,
        image_name="ciakTest.eps",
    )
