from concurrent.futures import ThreadPoolExecutor
from flask import Flask, request, jsonify
import kubernetes.client
from kubernetes.client.rest import ApiException
from kubernetes import config
from datetime import datetime
import sqlite3
import os
import requests
import threading
import socket
import json

# Two separate Flask applications so each endpoint is served on its own
# port (the __main__ block runs app_verify on 5000 and app_collect on 5001).
app_verify = Flask(__name__)
app_collect = Flask(__name__)
# NOTE(review): unused leftover -- forwarding is done over raw TCP in
# wait_and_forward_tcp. Consider deleting this line.
#app_forward = Flask(__name__)


from flask import request, jsonify
from threading import Lock, Event

# Shared state for a verification round. It is touched from Flask request
# threads (the servers run with threaded=True) and from the forwarding thread
# started by /verify.
# NOTE(review): the state is global, so overlapping /verify rounds share it.
# Payloads posted to /collect, in arrival order.
responses: list = []
# Guards `responses`; the completion check in /collect runs under it too.
responses_lock: Lock = Lock()
# Set by /collect once len(responses) reaches expected_responses; waited on
# by the forwarding thread.
all_responses_received: Event = Event()
# Number of /collect responses the current round is waiting for.
expected_responses: int = 0

def get_node_ip(pod_uuid):
    """Return the InternalIP of the node running the pod whose UID is *pod_uuid*.

    The pod is looked up by ``metadata.uid`` among the pods of the
    ``default`` namespace, using the local kubeconfig. Its node is then read.

    Args:
        pod_uuid: The pod's ``metadata.uid``.

    Returns:
        The node's ``InternalIP`` address, or ``None`` in any of these cases
        (each is reported with ``print``, never raised):
        - the pod is not found;
        - the pod has no node assigned yet;
        - the node reports no InternalIP;
        - a Kubernetes API call fails.
    """
    try:
        # Use load_kube_config for local development.
        # NOTE(review): this re-reads the kubeconfig on every call; an
        # in-cluster deployment would need load_incluster_config instead.
        config.load_kube_config()
        v1 = kubernetes.client.CoreV1Api()

        # The list response already contains each pod's spec, so the node
        # name is read from the matching item; no second pod read is needed.
        pods = v1.list_namespaced_pod(namespace="default")
        pod = next((p for p in pods.items if p.metadata.uid == pod_uuid), None)
        if pod is None:
            print(f"Pod with UUID {pod_uuid} not found")
            return None

        node_name = pod.spec.node_name
        if not node_name:
            # e.g. a Pending pod that the scheduler has not placed yet;
            # read_node(name=None) would otherwise raise.
            print(f"Pod {pod.metadata.name} is not assigned to a node yet")
            return None

        node = v1.read_node(name=node_name)
        for address in node.status.addresses or []:
            if address.type == "InternalIP":
                return address.address
        print(f"Node {node_name} reports no InternalIP address")
    except ApiException as e:
        print(f"Kubernetes API error while resolving node IP for pod {pod_uuid}: {e}")
    except Exception as e:
        print(f"Error connecting to Kubernetes API: {e}")
    return None

def check_verifier_id(verifier_id, db_path='verifier_pod_mapping.db'):
    """Return True if *verifier_id* appears in the ``verifier_pod_mapping`` table.

    Args:
        verifier_id: Verifier identifier to look up.
        db_path: SQLite database file. The default is the file this
            service has always used, relative to the working directory.

    Returns:
        ``True`` when at least one matching row exists, otherwise ``False``.

    Raises:
        sqlite3.Error: If the database or table cannot be queried. The
            connection is still closed.
    """
    conn = sqlite3.connect(db_path)
    try:
        row = conn.execute(
            "SELECT 1 FROM verifier_pod_mapping WHERE verifier_id=?",
            (verifier_id,),
        ).fetchone()
    finally:
        # Always release the connection, even when the query raises.
        conn.close()
    return row is not None

def check_verifier_pod_mapping(verifier_id, pod_id, db_path='verifier_pod_mapping.db'):
    """Return True if *verifier_id* is mapped to the pod *pod_id*.

    Args:
        verifier_id: Verifier identifier.
        pod_id: Pod UID, compared against the ``pod_uuid`` column.
        db_path: SQLite database file. The default is the file this
            service has always used, relative to the working directory.

    Returns:
        ``True`` when a row with both values exists, otherwise ``False``.

    Raises:
        sqlite3.Error: If the database or table cannot be queried. The
            connection is still closed.
    """
    conn = sqlite3.connect(db_path)
    try:
        row = conn.execute(
            "SELECT 1 FROM verifier_pod_mapping WHERE verifier_id=? AND pod_uuid=?",
            (verifier_id, pod_id),
        ).fetchone()
    finally:
        # Always release the connection, even when the query raises.
        conn.close()
    return row is not None
    
@app_verify.route('/verify', methods=['POST'])
def verify():
    """Request integrity quotes for a verifier's pods and start forwarding.

    Expects a JSON object with these fields:
    - ``verifier_id``;
    - ``pods``: a list of objects with ``pod_id`` and optional
      ``whitelist``, ``exclude_list``, ``agent_port`` and ``agent_version``;
    - ``nonce`` and ``mask``;
    - optional ``partial`` and ``ima_ml_entry``.

    All checks run before any side effect: body shape, verifier and per-pod
    authorization, and the required query values. A rejected request
    therefore never resets the shared round state, starts a forwarding
    thread, or contacts an agent.

    Returns:
        ``(json, 200)`` with one result per pod, in request order.
        400 for a malformed body, or a missing ``nonce``/``mask`` when an
        agent must be contacted.
        401 if the verifier or any requested pod is not authorized.
    """
    global expected_responses
    # silent=True turns a missing/invalid JSON body into None (answered with
    # 400 below) instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"message": "Invalid JSON body"}), 400
    verifier_id = data.get('verifier_id')
    pods = data.get('pods', [])

    if not verifier_id or not check_verifier_id(verifier_id):
        return jsonify({"message": "Unauthorized"}), 401
    if not isinstance(pods, list) or not all(isinstance(p, dict) for p in pods):
        return jsonify({"message": "Invalid 'pods' field"}), 400

    # Authorize every pod before contacting any agent.
    for pod in pods:
        pod_id = pod.get('pod_id')
        if pod_id and not check_verifier_pod_mapping(verifier_id, pod_id):
            return jsonify({"message": "Not authorized to attest some pods"}), 401

    # Resolve node IPs first, so the expected-response count only covers
    # pods that an agent request will actually be sent to.
    # Each entry: (pod, node_ip or None, precomputed error result or None).
    plan = []
    for pod in pods:
        pod_id = pod.get('pod_id')
        if not pod_id:
            plan.append((pod, None, {"error": "Pod ID missing"}))
            continue
        node_ip = get_node_ip(pod_id)
        if node_ip:
            plan.append((pod, node_ip, None))
        else:
            plan.append((pod, None, {"pod_id": pod_id, "error": "Node IP not found"}))

    targets = sum(1 for _, ip, _ in plan if ip)

    query = {}
    if targets:
        if "nonce" not in data or "mask" not in data:
            return jsonify({"message": "Missing 'nonce' or 'mask'"}), 400
        query = {
            "nonce": data["nonce"],
            "mask": data["mask"],
            "partial": data.get("partial", "0"),
            "ima_ml_entry": data.get("ima_ml_entry", "0"),
        }
        # Start a fresh round. Drop responses left over from an earlier
        # (e.g. timed-out) round, and publish the new target under the same
        # lock that /collect holds when it reads it.
        # NOTE(review): the round state is global, so overlapping /verify
        # calls still interfere with each other.
        with responses_lock:
            responses.clear()
            all_responses_received.clear()
            expected_responses = targets
        # Start the forwarding thread, which waits for all the responses.
        threading.Thread(target=wait_and_forward_tcp, args=(
            verifier_id, targets, all_responses_received, responses, responses_lock
        )).start()

    results = []
    for pod, node_ip, error in plan:
        if error is not None:
            results.append(error)
            continue
        pod_id = pod.get('pod_id')
        results.append({
            "pod_id": pod_id,
            "node_ip": node_ip,
            "whitelist": pod.get('whitelist', []),
            "exclude_list": pod.get('exclude_list', []),
            "quote": _fetch_quote(
                node_ip,
                pod.get("agent_port"),
                pod.get("agent_version", "2.2"),
                {**query, "pod_id": pod_id},
            ),
        })

    return jsonify({"message": "Verification completed", "results": results}), 200


def _fetch_quote(node_ip, agent_port, agent_version, params):
    """GET an integrity quote from the agent on *node_ip*; return its JSON or an error dict."""
    agent_url = f"http://{node_ip}:{agent_port}/v{agent_version}/quotes/integrity"
    try:
        # params= URL-encodes the query values (e.g. '+' or '&' in a nonce).
        r = requests.get(agent_url, params=params, timeout=10)
        r.raise_for_status()
        return r.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose
        # decode error is not a RequestException.
        return {"error": f"Failed to get quote: {e}"}

@app_collect.route('/collect', methods=['POST'])
def collect():
    """Store one agent response for the current verification round.

    Expects a JSON object with a truthy ``response`` field. Once the number
    of stored responses reaches ``expected_responses``, the
    ``all_responses_received`` event is set.

    Returns:
        200 on success.
        400 if the body is not a JSON object or lacks ``response``.
    """
    # silent=True: a missing/invalid JSON body yields None instead of
    # raising, so it is answered with a 400 rather than an unhandled 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Invalid JSON body"}), 400

    # Extract the quote from the JSON payload, if present.
    response_json = data.get("response")
    if not response_json:
        return jsonify({"error": "Missing 'response' field"}), 400

    with responses_lock:
        responses.append(response_json)
        # ">=" rather than "==": if leftover entries push the count past the
        # target, the event still fires instead of never matching. The
        # "0 <" guard keeps posts made while no round is set up
        # (expected_responses == 0) from setting the event.
        if 0 < expected_responses <= len(responses):
            all_responses_received.set()

    return jsonify({"message": "Response collected"}), 200



def wait_and_forward_tcp(verifier_id, expected_responses,
                         all_responses_received, responses, responses_lock,
                         host="localhost", port=9999, wait_timeout=10,
                         connect_timeout=5):
    """Wait for the round's responses, then forward the first one over TCP.

    Blocks until *all_responses_received* is set, or *wait_timeout* seconds
    pass. It then takes ``responses[0]``, resets the shared buffer and
    event, and sends that response as JSON to the TCP server at
    ``host:port``. One ``recv`` of up to 4096 bytes is read back and
    printed.

    Whatever the outcome (timeout, empty buffer, or success), the buffer
    and event are reset, so nothing leaks into the next round.

    Args:
        verifier_id: Verifier that started the round (not used here).
        expected_responses: Number of responses the round expects; only
            logged.
        all_responses_received: Event that signals the round is complete.
        responses: Shared list of collected response payloads.
        responses_lock: Lock guarding *responses*.
        host: Host of the verifier TCP server (default ``"localhost"``).
        port: Port of the verifier TCP server (default ``9999``).
        wait_timeout: Seconds to wait for the event.
        connect_timeout: Socket connect/IO timeout in seconds.

    Returns:
        None. All outcomes are reported with ``print``.
    """
    print(f"Waiting for {expected_responses} response(s)...")

    if not all_responses_received.wait(timeout=wait_timeout):
        print("Timeout waiting for all responses.")
        # Drop the partial batch so it is not counted toward, or forwarded
        # as part of, the next round.
        with responses_lock:
            responses.clear()
            all_responses_received.clear()
        return

    with responses_lock:
        if not responses:
            print("No valid response.")
            # Also clear the event, or the next wait returns immediately.
            all_responses_received.clear()
            return

        quote_response = responses[0]  # full response JSON from Rust
        responses.clear()
        all_responses_received.clear()

    payload = json.dumps(quote_response).encode()
    try:
        with socket.create_connection((host, port), timeout=connect_timeout) as sock:
            sock.sendall(payload)
            print("Quote sent to the verifier by TCP")

            reply = sock.recv(4096)
            print("Response from verifier:", reply.decode(errors="replace"))
    except OSError as e:
        # Covers refused connections, timeouts and resets.
        print(f"Error sending to verifier at {host}:{port}: {e}")

def run_app(app, port):
    """Serve *app* with Flask's threaded server on all interfaces at *port*."""
    bind_all = '0.0.0.0'
    app.run(port=port, host=bind_all, threaded=True)

if __name__ == '__main__':
    from concurrent.futures import as_completed

    # Each Flask server blocks for the life of the process, so every app
    # gets its own worker thread: /verify on 5000, /collect on 5001.
    servers = [(app_verify, 5000), (app_collect, 5001)]
    with ThreadPoolExecutor(max_workers=len(servers)) as executor:
        futures = {executor.submit(run_app, app, port): port for app, port in servers}
        # Without this loop, an exception from app.run (e.g. the port is
        # already in use) would stay hidden inside its Future, and the
        # process would silently serve only the other app.
        for future in as_completed(futures):
            exc = future.exception()
            if exc is not None:
                print(f"Server on port {futures[future]} stopped with error: {exc!r}")