import jetson_inference
import jetson_utils
import time
import cv2
import numpy as np 
import tkinter as tk
from tkinter import ttk
from tkinter import messagebox
import math
import json
import argparse
import imutils

def Follow_Road_Core(send_activate, send_command, send_stop, truncate, low_b, high_b):
    
    net = jetson_inference.detectNet("best_model.onnx", threshold=0.9, 
                input_blob="input_0", output_cvg="scores", output_bbox="boxes", labels="labels.txt") #0.65

    cap = cv2.VideoCapture(0)
    cap.set(3, 160)
    cap.set(4, 120)

    max_steering_angle = 30  # Limita l'angolo massimo di sterzata
    stop_detected = False   # Stato: segnale Stop attivo
    stop_start_time = 0     # Tempo in cui lo Stop è stato rilevato
    stop_cooldown = 20       # Durata dello stop in secondi (modifica questo valore in base alle esigenze)

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                continue

            # Elaborazione per il riconoscimento dei segnali (prioritario)
            detections = detection(frame, net, truncate)
            if not stop_detected and detections:
                for detect in detections:
                    item = net.GetClassDesc(detect.ClassID).strip()
                    if item == "Stop":
                        print("Stop Sign Detected!")
                        send_stop()  # Invia immediatamente il comando di Stop
                        stop_detected = True
                        stop_start_time = time.time()
                        # Esce subito dalla gestione dei segnali per evitare ulteriori comandi in questo frame
                        break

            # Se siamo in stato di Stop, attendi il cooldown e NON eseguire il tracking
            if stop_detected:
                if time.time() - stop_start_time < stop_cooldown:
                    cv2.imshow("Frame", frame)
                    if cv2.waitKey(1) & 0xff == ord('q'):
                        break
                    continue  # Salta il tracking finché non scade il cooldown
                else:
                    print("Stop cooldown terminato, riprendo il tracking della linea.")
                    stop_detected = False
                    send_activate()  # Riattiva il sistema

            # Elaborazione per il tracking della linea
            hsv_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
            # Soglie per rilevare la linea (green line)
            high_b = np.uint8([38, 71, 125])
            low_b = np.uint8([77, 255, 255])
            mask = cv2.inRange(hsv_frame, high_b, low_b)
            contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

            if contours:
                c = max(contours, key=cv2.contourArea)
                M = cv2.moments(c)
                if M["m00"] != 0:
                    cx = int(M["m10"] / M["m00"])
                    cy = int(M["m01"] / M["m00"])
                    # Calcola l'errore rispetto al centro dell'immagine
                    image_center_x = frame.shape[1] / 2
                    error_x = cx - image_center_x
                    steering_angle = (error_x / image_center_x) * max_steering_angle
                    print(f"angle: {steering_angle}")

                    if abs(error_x) < 10:  # Centro
                        print("On track forward!")
                        send_activate()
                        send_command("forward")
                    elif error_x < -10:  # Sinistra
                        print("Turning left!")
                        send_activate()
                        send_command("left", int(-1 * steering_angle))
                    elif error_x > 10:  # Destra
                        print("Turning right!")
                        send_activate()
                        send_command("right", int(steering_angle))

                    cv2.circle(frame, (cx, cy), 5, (255, 255, 255), -1)
                cv2.drawContours(frame, [c], -1, (0, 255, 0), 1)

            cv2.imshow("Mask", mask)
            cv2.imshow("Frame", frame)

            if cv2.waitKey(1) & 0xff == ord('q'):
                break

    finally:
        cap.release()
        cv2.destroyAllWindows()
    
    return

def detection(frame, net, truncate):

    height=frame.shape[0]
    width=frame.shape[1]

    frame_color = cv2.cvtColor(frame, cv2.COLOR_BGR2RGBA).astype(np.float32)
    frame_cuda = jetson_utils.cudaFromNumpy(frame_color)

    detections=net.Detect(frame_cuda, width, height)

    matching_detections= []
    rects = []
    all_objects= []

    for detect in detections:

        ID = detect.ClassID
        #confidence = truncate(float(detect.Confidence),2)
        confidence = detect.Confidence
        top = int(detect.Top)
        left = int(detect.Left)
        bottom = int(detect.Bottom)
        right = int(detect.Right)
        item = net.GetClassDesc(ID)
        item = item.strip()  # Rimuove spazi e caratteri invisibili come \r e \n
        box = (left,top,right,bottom)

        matching_detections.append(detect)
        box_stop = (left,top,right,bottom)   
        rects.append(box_stop)
        #cv2.putText(frame, item+" "+str(confidence), (left,top+20), font, .75, (0,0,255), 2)    
        cv2.putText(frame, f"{item} ({confidence:.2f})", (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        cv2.rectangle(frame,(left,top),(right,bottom),(0,255,0),2)
        
    return matching_detections