import numpy as np
from scipy.spatial import distance as dist
from collections import OrderedDict

class Centroids:
    """Centroid tracker that assigns persistent IDs to per-frame detections.

    Each call to :meth:`centroids_recalculator` receives the bounding boxes
    detected in one frame, converts them to centroids and matches them to the
    centroids tracked so far by Euclidean distance. Matched centroids are moved
    to their new position, unmatched detections are registered under a new ID,
    and tracked centroids that go unmatched accumulate a "disappeared" counter
    until they are deregistered.
    """

    def __init__(self, max_disapp_frames=30):
        """Create an empty tracker.

        Args:
            max_disapp_frames: number of consecutive frames a centroid may go
                unmatched while still being tracked; it is deregistered on the
                next unmatched frame after that.
        """
        # Next never-issued ID. This counts issued IDs, not the centroids
        # currently tracked, because deregistered IDs go to available_ids.
        self.centroids_id = 0
        # ID -> current centroid (x, y). Insertion-ordered, so iteration
        # follows registration order.
        self.centroids = OrderedDict()
        # ID -> number of consecutive frames the centroid went unmatched.
        self.disapp_centroids = OrderedDict()
        self.max_disappearance_frames = max_disapp_frames
        # IDs freed by deregistration; reused oldest-first before new IDs.
        self.available_ids = []

    def unregister_centroid(self, ID):
        """Stop tracking centroid ``ID`` and put the ID back in the reuse pool."""
        del self.centroids[ID]
        del self.disapp_centroids[ID]
        self.available_ids.append(ID)

    def record_new_centroid(self, centroid):
        """Start tracking ``centroid`` with its disappeared counter at 0.

        A previously freed ID is reused (oldest freed first) when available;
        otherwise a fresh ID is issued.

        Returns:
            The ID assigned to the new centroid.
        """
        if self.available_ids:
            new_id = self.available_ids.pop(0)
        else:
            new_id = self.centroids_id
            self.centroids_id += 1
        self.centroids[new_id] = centroid
        self.disapp_centroids[new_id] = 0
        return new_id

    def _mark_disappeared(self, ID):
        """Count one more unmatched frame for ``ID``; deregister it if expired.

        Shared by every code path so the expiry rule is identical whether or
        not the frame had detections: a centroid survives up to
        ``max_disappearance_frames`` consecutive unmatched frames.
        """
        self.disapp_centroids[ID] += 1
        if self.disapp_centroids[ID] > self.max_disappearance_frames:
            self.unregister_centroid(ID)

    def empty_case(self):
        """Handle a frame with no detections: every tracked centroid is unmatched."""
        # Iterate over a snapshot because _mark_disappeared may delete keys.
        for ID in list(self.disapp_centroids):
            self._mark_disappeared(ID)

    def center_calculator(self, bxs, ret):
        """Write the centre of each box in ``bxs`` into ``ret`` and return it.

        Args:
            bxs: iterable of ``(x1, y1, x2, y2)`` boxes.
            ret: writable container indexed by box position (e.g. an integer
                array with one row per box); row ``i`` receives the centre of
                box ``i``, each coordinate truncated with ``int``.

        Returns:
            ``ret``, filled in place.
        """
        for i, (x1, y1, x2, y2) in enumerate(bxs):
            ret[i] = (int((x1 + x2) / 2.0), int((y1 + y2) / 2.0))
        return ret

    def modify_centroids(self, inputcentrs):
        """Match tracked centroids to ``inputcentrs`` and update the tracker.

        Args:
            inputcentrs: array of ``(x, y)`` centroids detected in the current
                frame. It must be non-empty, and at least one centroid must
                currently be tracked (``cdist`` is given both as 2-D arrays).
        """
        centrIDs = list(self.centroids.keys())
        centrs = list(self.centroids.values())

        # D[r, c] = Euclidean distance between tracked centroid r and input c.
        D = dist.cdist(np.array(centrs), inputcentrs)

        # Greedy matching: visit tracked centroids from the one closest to any
        # input to the farthest, giving each its nearest input unless that
        # input has already been claimed by an earlier (closer) centroid.
        rows = D.min(axis=1).argsort()
        cols = D.argmin(axis=1)[rows]

        usedOlds = set()
        usedNews = set()
        for old, new in zip(rows, cols):
            if old in usedOlds or new in usedNews:
                continue
            centroidID = centrIDs[old]
            self.centroids[centroidID] = inputcentrs[new]
            self.disapp_centroids[centroidID] = 0
            usedOlds.add(old)
            usedNews.add(new)

        # Two tracked centroids can share the same nearest input. The greedy
        # pass can then leave BOTH a tracked centroid and an input unmatched,
        # whatever the relative counts, so handle both leftovers every frame.
        # Register unmatched inputs first, so an ID freed below is not handed
        # to a different object within the same frame.
        for new in range(D.shape[1]):
            if new not in usedNews:
                self.record_new_centroid(inputcentrs[new])

        for old in range(D.shape[0]):
            if old not in usedOlds:
                self._mark_disappeared(centrIDs[old])

    def centroids_recalculator(self, boxes):
        """Update the tracker with one frame's detections.

        Args:
            boxes: sequence of ``(x1, y1, x2, y2)`` bounding boxes detected in
                the current frame (may be empty).

        Returns:
            The tracker's ``OrderedDict`` mapping ID -> centroid (the live
            object, not a copy).
        """
        # len() rather than truthiness so numpy arrays are accepted too.
        if len(boxes) == 0:
            self.empty_case()
            return self.centroids

        input_centroids = np.zeros((len(boxes), 2), dtype="int")
        input_centroids = self.center_calculator(boxes, input_centroids)

        # Nothing tracked yet: every detection becomes a new centroid.
        if len(self.centroids) == 0:
            for centroid in input_centroids:
                self.record_new_centroid(centroid)
            return self.centroids

        self.modify_centroids(input_centroids)
        return self.centroids
