import argparse
import json
import os
import sys
import collections
import collections.abc
import ast
import base64
# --- Compatibility Fix ---
# Python 3.10 removed the deprecated ABC aliases from the top-level
# `collections` module; they now live only in `collections.abc`.
# Re-expose the two aliases below so that code still referencing
# `collections.MutableSet` / `collections.MutableMapping` keeps working.
# NOTE(review): presumably needed by `pcap_ioc` (or one of its dependencies),
# which is why this sits before the third-party imports — confirm.
collections.MutableSet = collections.abc.MutableSet
collections.MutableMapping = collections.abc.MutableMapping

import time
from urllib.parse import quote

import pandas as pd
import requests
from pcap_ioc import Pcap

# --- Configuration (Global Constants) ---
# Append the directory two levels above this file to the module search path.
# NOTE(review): presumably that is where `config.py` lives, which is what
# makes the `from config import ...` below resolvable — confirm.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from config import (
    VT_API_KEY,
    VT_BASE_URL,
    VT_RATE_LIMIT_DELAY,
    VT_RESULT_FILENAME,
    VT_JSON_LINES_LIMIT
)

# --- VirusTotal Functions ---
def query_virustotal(ioc: str, timeout: float = 30.0) -> "dict | None":
    """
    Query the VirusTotal API for information about a given IOC (Indicator of Compromise).

    Args:
        ioc (str): A string representation of a dictionary with keys 'type' and 'value' representing the IOC.
        timeout (float): Seconds to wait for the HTTP request before giving up. Defaults to 30.

    Returns:
        dict | None: ``None`` when the IOC type is 'url' (URL lookups are skipped and no
              request is made). Otherwise a dictionary whose 'ioc' entry is the parsed IOC
              dictionary, plus:
              - on HTTP 200: counts of malicious/suspicious/harmless/undetected detections and,
                if malicious or suspicious detections exist, the full JSON report under 'full_report';
              - on HTTP 404: 'not_found': True;
              - on any other status: 'error': 'HTTP <status_code>';
              - on a network failure or timeout: 'error': '<ExceptionName>: <message>'.

    Raises:
        ValueError, SyntaxError: If ``ioc`` is not a valid Python literal.
        KeyError: If the parsed IOC lacks the 'type' or 'value' key.
    """
    # literal_eval only parses Python literals; it never executes code.
    indicator = ast.literal_eval(ioc)
    ioc_type = indicator['type']
    ioc_value = indicator['value']

    # Map IOC types onto the collection names used in the API path.
    # Any other type is passed through unchanged.
    if ioc_type == 'domain':
        ioc_type = 'domains'
    elif ioc_type == 'ip':
        ioc_type = 'ip_addresses'
    elif ioc_type == 'url':
        # URL analysis is intentionally skipped. (A lookup would need the
        # 'urls' collection and the URL encoded as unpadded URL-safe base64.)
        return None

    # The value comes from captured network traffic, so escape it before
    # placing it in the URL path. ':' stays literal for IPv6 addresses.
    url = f"{VT_BASE_URL}/{ioc_type}/{quote(str(ioc_value), safe=':')}"
    headers = {"x-apikey": VT_API_KEY}

    print(f"[*] Querying: {url}")
    try:
        # Without a timeout a stalled connection would block the run forever.
        resp = requests.get(url, headers=headers, timeout=timeout)
    except requests.RequestException as exc:
        # Report the failure in the same shape as an HTTP error so that one
        # unreachable lookup does not abort the remaining IOCs.
        print(f"[!] Request failed: {exc}")
        return {"ioc": indicator, "error": f"{type(exc).__name__}: {exc}"}

    print(f"[*] HTTP Response: {resp}")
    if resp.status_code == 200:
        json_data = resp.json()
        attr = json_data.get("data", {}).get("attributes", {})
        stats = attr.get("last_analysis_stats", {})

        result = {
            "ioc": indicator,
            "malicious": stats.get("malicious", 0),
            "suspicious": stats.get("suspicious", 0),
            "harmless": stats.get("harmless", 0),
            "undetected": stats.get("undetected", 0),
        }

        # Keep the (large) raw report only for hits worth investigating.
        if result["malicious"] > 0 or result["suspicious"] > 0:
            result["full_report"] = json_data
        return result

    if resp.status_code == 404:
        return {"ioc": indicator, "not_found": True}
    return {"ioc": indicator, "error": f"HTTP {resp.status_code}"}

def save_vt_results_csv(results: list[dict], output_folder: str):
    """
    Write the VirusTotal query results to a CSV file, leaving out the full JSON reports.

    Args:
        results (list[dict]): VirusTotal query result dictionaries, one per IOC.
        output_folder (str): Folder in which the CSV file (named VT_RESULT_FILENAME) is written.

    Returns:
        None
    """
    csv_path = os.path.join(output_folder, VT_RESULT_FILENAME)

    # Drop 'full_report' from every row before saving; the copies leave the
    # caller's dictionaries untouched.
    rows = [
        {key: value for key, value in entry.items() if key != "full_report"}
        for entry in results
    ]

    pd.DataFrame(rows).to_csv(csv_path, index=False)
    print(f"[+] VirusTotal results saved to {csv_path} (without full reports)")

def save_full_json_report(result: dict, output_folder: str):
    """
    Write the full VirusTotal JSON report carried by one query result to disk.

    The file is named ``<ioc type>_<ioc value>.json``, with ':' and '/' in the
    value replaced by '_'. Nothing is written when the result has no (or an
    empty) 'full_report' entry.

    Args:
        result (dict): A single query result dictionary which may contain a 'full_report' key.
        output_folder (str): Directory path where the JSON file will be saved.

    Returns:
        None
    """
    report = result.get("full_report")
    if report:
        indicator = result["ioc"]
        # ':' and '/' appear in IOC values but are not wanted in a file name.
        safe_value = indicator["value"].translate(str.maketrans({":": "_", "/": "_"}))
        filename = os.path.join(output_folder, f"{indicator['type']}_{safe_value}.json")

        with open(filename, "w") as handle:
            json.dump(report, handle, indent=2)

        print(f"[+] Saved full report to {filename}")

def print_summary(results: list[dict], output_folder: str):
    """
    Print every IOC flagged as malicious or suspicious and save its full VirusTotal report.

    For each flagged result the detection counts are printed. When the result
    also carries a 'full_report', the first VT_JSON_LINES_LIMIT lines of its
    pretty-printed JSON are printed and the complete report is written to
    ``output_folder``.

    Args:
        results (list[dict]): List of VirusTotal query results.
        output_folder (str): Directory where full reports will be saved if applicable.

    Returns:
        None
    """
    print("\n[!] Malicious or suspicious IoCs:")
    for entry in results:
        malicious_hits = entry.get("malicious", 0)
        suspicious_hits = entry.get("suspicious", 0)
        # Results without any hit (clean, not found, errored) are not reported.
        if not (malicious_hits > 0 or suspicious_hits > 0):
            continue

        print(f"\n--- {entry['ioc']} ---")
        print(f"malicious={malicious_hits}, suspicious={suspicious_hits}")

        report = entry.get("full_report")
        if not report:
            continue

        # Only a bounded preview goes to the console; the file gets everything.
        preview = json.dumps(report, indent=2).splitlines()[:VT_JSON_LINES_LIMIT]
        print("\n".join(preview))
        save_full_json_report(entry, output_folder)

# --- IOC Processing Functions ---
def extract_iocs_from_pcap(pcap_file: str) -> set[str]:
    """
    Collect the possible Indicators of Compromise (IoCs) found in a PCAP file.

    Args:
        pcap_file (str): Path to the PCAP file.

    Returns:
        set[str]: The ``str()`` form of each indicator reported by ``Pcap``, de-duplicated.
    """
    capture = Pcap(pcap_file)
    # Stringify each indicator; the set collapses identical strings.
    found = set(map(str, capture.indicators))
    print(f"[+] Extracted {len(found)} possible IoCs from PCAP")
    return found

def process_virustotal_results(iocs: set[str], output_folder: str) -> list[dict]:
    """
    Query VirusTotal for each IOC and save the aggregated results to CSV.

    IOCs for which ``query_virustotal`` returns ``None`` (unsupported types such
    as URLs) are skipped. After every other lookup the function waits
    VT_RATE_LIMIT_DELAY seconds, whether or not the lookup succeeded.

    Args:
        iocs (set[str]): Set of stringified IOCs to query.
        output_folder (str): Path to directory where results will be saved.

    Returns:
        list[dict]: List of dictionaries with VirusTotal query results for each queried IOC.
    """
    print("[*] Querying VirusTotal for each IOC...")
    results = []

    for ioc in iocs:
        res = query_virustotal(ioc)

        if res is None:
            continue  # skip unsupported IOCs (e.g. URLs)

        results.append(res)

        # Respect VT rate limiting. Errored lookups are throttled too:
        # skipping the pause after an error (e.g. a rate-limit rejection)
        # would fire the next request immediately and keep the failure going.
        time.sleep(VT_RATE_LIMIT_DELAY)

    save_vt_results_csv(results, output_folder)
    return results

# --- Main ---
def main():
    """
    Command-line entry point: extract IoCs from a PCAP, query VirusTotal, report.

    Results are written under ``<case-folder>/03_processing/network``, which is
    created if it does not exist.
    """
    cli = argparse.ArgumentParser()
    # NOTE: --source and --local are accepted but not read anywhere in this function.
    cli.add_argument('-s', '--source', choices=['misp', 'local'], default='local', help="Source: 'local' or 'misp'")
    cli.add_argument('file', help='PCAP file to process')
    cli.add_argument('--local', help='Path to local IOC list (if source=local)')
    cli.add_argument('--case-folder', required=True, help='Path to case folder')
    options = cli.parse_args()

    report_dir = os.path.join(options.case_folder, "03_processing", "network")
    os.makedirs(report_dir, exist_ok=True)

    # Pipeline: extract -> query and save CSV -> print summary and save full reports.
    extracted = extract_iocs_from_pcap(options.file)
    vt_results = process_virustotal_results(extracted, report_dir)
    print_summary(vt_results, report_dir)

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
