import random
import os
import sys
import shutil
import csv
import logging
import subprocess
from pathlib import Path
from collections import defaultdict

from dfvfs.resolver import resolver
import argparse
from dfvfs.helpers.volume_scanner import (
    VolumeScanner,
    VolumeScannerOptions,
    VolumeScannerMediator
)
from dfvfs.helpers.file_system_searcher import (
    FileSystemSearcher,
    FindSpec
)
from Registry import Registry

# ========== CONFIGURATION ==========
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from config import (
    REGISTRY_KEYS_TO_EXTRACT
)


def extract_registry_keys(hive_path: Path, output_txt: Path):
    """
    Extract autorun-related registry keys from a hive file and save them to output_txt.
    Uses python-registry (Registry).

    For every key path in REGISTRY_KEYS_TO_EXTRACT that exists in the hive, a
    ``[key_path]`` header is written, followed by one ``name = repr(value)`` line
    per value (unnamed values are shown as ``(Default)``; values that cannot be
    decoded are written as ``<unreadable>``). Keys absent from the hive are skipped
    silently; other per-key failures are logged and the remaining keys are still
    processed. Errors opening the hive or writing the output are logged, not raised.

    Args:
        hive_path: local path to the registry hive file.
        output_txt: text file to create (overwritten if it already exists).
    """
    try:
        reg = Registry.Registry(str(hive_path))
    except Exception as e:
        logger.warning(f"Failed to open registry hive {hive_path}: {e}")
        return

    # python-registry exposes RegistryKeyNotFoundException on the Registry module.
    # Fall back to an empty tuple, which an `except` clause never matches, in case
    # the installed version lacks it.
    key_not_found = getattr(Registry, "RegistryKeyNotFoundException", ())

    try:
        with open(output_txt, "w", encoding="utf-8") as out:
            out.write(f"Registry autorun keys extracted from {hive_path.name}\n")
            out.write("=" * 60 + "\n\n")
            for key_path in REGISTRY_KEYS_TO_EXTRACT:
                try:
                    key = reg.open(key_path)
                except key_not_found:
                    # Key missing from this hive: expected, skip quietly.
                    continue
                except Exception as e_open:
                    # Unexpected error opening the key: log it and move on.
                    logger.warning(f"Error opening key {key_path} in {hive_path}: {e_open}")
                    continue

                out.write(f"[{key_path}]\n")
                try:
                    for val in key.values():
                        name = val.name() or "(Default)"
                        try:
                            val_data = val.value()
                        except Exception:
                            val_data = "<unreadable>"
                        out.write(f"  {name} = {repr(val_data)}\n")
                except Exception as e_vals:
                    logger.warning(f"Failed enumerating values for {key_path} in {hive_path}: {e_vals}")

                out.write("\n")
        logger.debug(f"Written registry extraction to {output_txt}")
    except Exception as e:
        logger.error(f"Failed to write registry output {output_txt}: {e}", exc_info=True)


# =======================
# GLOBAL CONFIGURATION
# =======================
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from config import (
    GEMINI_AGENT,
    OUTPUT_SUBDIR_DISK,
    EXTRACTION_MAP_FILENAME,
    LINUX_ARTIFACT_PATHS,
    WINDOWS_ARTIFACT_PATHS,
    SCANDISK_LOG_LEVEL
)

# Overwritten by process_volume() with the OS label detected for the volume it
# processed most recently ("linux", "windows" or None).
os_type = None

# =======================
# LOGGING CONFIGURATION
# =======================
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=SCANDISK_LOG_LEVEL)
# Module-wide logger used throughout this file.
logger = logging.getLogger(__name__)

# =======================
# CLASSES
# =======================
class CustomMediator(VolumeScannerMediator):
    """
    Custom mediator for volume scanning, allowing partition selection.

    Only GetPartitionIdentifiers is customised: it always narrows the candidates
    down to a single partition.
    """

    def GetPartitionIdentifiers(self, volume_system, volume_identifiers):
        """Select the partition the volume scanner should process.

        Args:
            volume_system: dfvfs volume system that owns the partitions.
            volume_identifiers: identifiers of the candidate partitions.

        Returns:
            list: a one-element list holding the first identifier when the volume
            system reports an NTFS type, otherwise the identifier with the largest
            summed extent size (first one wins on ties); an empty list when no
            identifiers were supplied.
        """
        type_indicator = getattr(volume_system, "type_indicator", None) or ""
        logger.debug(f"Getting partition identifiers from volume system type '{type_indicator}'")

        partitions = []
        for identifier in volume_identifiers or []:
            size = self._GetVolumeSize(volume_system, identifier)
            partitions.append((identifier, size))
            logger.debug(f"Found partition '{identifier}' with extent {size}")

        if not partitions:
            # max() below would raise ValueError on an empty sequence.
            logger.warning("Volume system offered no partitions to select from")
            return []

        # NOTE(review): dfvfs volume-system type indicators name partition schemes
        # (e.g. TSK_PARTITION, GPT), so this NTFS shortcut probably never fires.
        # It is kept to preserve the original selection rule; confirm the intent.
        if type_indicator.lower() == "ntfs":
            first_identifier = partitions[0][0]
            logger.info(f"Volume system type is NTFS, selecting partition '{first_identifier}'")
            return [first_identifier]

        largest_partition = max(partitions, key=lambda tup: tup[1])[0]
        logger.info(f"No NTFS partition found, selecting largest partition '{largest_partition}'")
        return [largest_partition]

    @staticmethod
    def _GetVolumeSize(volume_system, identifier):
        """Return the summed size of a partition's extents, or 0 if it is unknown."""
        # The volume system itself has no GetSubNodeByLocation (that belongs to scan
        # nodes); partitions are looked up by identifier and their extents summed.
        volume = volume_system.GetVolumeByIdentifier(identifier)
        if volume is None:
            return 0
        extents = getattr(volume, "extents", None) or []
        return sum(getattr(extent, "size", 0) or 0 for extent in extents)

# =======================
# FUNCTIONS
# =======================
def copy_file_entry_to_local(file_entry, local_output_path: Path) -> bool:
    """Copy a dfvfs file entry to a local path, and parse registry hives if applicable.

    The entry's contents are streamed in chunks to *local_output_path*. Missing
    parent directories are created. When the entry is named NTUSER.DAT, SOFTWARE
    or SYSTEM (case-insensitive), extract_registry_keys() is also run on the copy.
    Its output goes to a sibling file whose suffix is replaced by ``.txt``.

    Args:
        file_entry: dfvfs file entry to copy.
        local_output_path: destination path on the local file system.

    Returns:
        True if the data was copied. False if the entry exposes no file object or
        any step failed; errors are logged, not raised.
    """
    chunk_size = 1024 * 1024  # 1 MiB reads: far fewer calls than 4 KiB on large artifacts
    try:
        local_output_path = Path(local_output_path)
        logger.debug(f"Copying file '{file_entry.name}' to '{local_output_path}'")

        # Open the source first so no empty placeholder file is created when there
        # is nothing to read.
        file_object = file_entry.GetFileObject()
        if file_object is None:
            logger.warning(f"No file object available for '{file_entry.name}', skipping copy")
            return False

        try:
            local_output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(local_output_path, "wb") as local_file:
                while (buf := file_object.read(chunk_size)):
                    local_file.write(buf)
        finally:
            # Always release the source handle, even if reading or writing failed.
            close = getattr(file_object, "close", None)
            if callable(close):
                close()
        logger.debug(f"Successfully copied '{file_entry.name}' to '{local_output_path}'")

        # --- If it's a registry hive, extract autorun keys ---
        if file_entry.name.upper() in ("NTUSER.DAT", "SOFTWARE", "SYSTEM"):
            txt_out = local_output_path.with_suffix(".txt")
            extract_registry_keys(local_output_path, txt_out)
            logger.info(f"Registry autorun keys extracted to '{txt_out}'")

        return True
    except Exception as e:
        logger.error(f"Error copying '{file_entry.name}' → '{local_output_path}': {e}", exc_info=True)
        return False


def detect_os_and_artifacts(fs_type: str):
    """Map a file system type indicator to an OS label and its artifact globs.

    Matching is a case-insensitive substring test, and Linux markers are checked
    before Windows markers.

    Returns:
        tuple: ("linux", LINUX_ARTIFACT_PATHS), ("windows", WINDOWS_ARTIFACT_PATHS),
        or (None, []) when the type is not recognised.
    """
    normalized = fs_type.lower()
    logger.debug(f"Detecting OS and artifacts for file system type '{normalized}'")

    if any(marker in normalized for marker in ("ext", "linux")):
        logger.info("Detected Linux file system")
        return "linux", LINUX_ARTIFACT_PATHS

    if any(marker in normalized for marker in ("ntfs", "windows")):
        logger.info("Detected Windows file system")
        return "windows", WINDOWS_ARTIFACT_PATHS

    logger.warning(f"Unknown file system type '{normalized}'")
    return None, []


def _location_separator(location: str) -> str:
    """Return the separator a dfvfs location uses: '\\' if it starts with one, else '/'."""
    return "\\" if location and location[0] == "\\" else "/"


def _unique_local_path(candidate: Path, hint: str = "") -> Path:
    """Return *candidate* if it does not exist yet, otherwise a free sibling name.

    A clashing name first gets ``_<hint>`` inserted before its suffix (when a hint
    is given). If that is taken too, ``_<n>`` counters are appended until a free
    name is found, so an existing file or directory is never overwritten.
    """
    if not candidate.exists():
        return candidate
    stem, suffix = candidate.stem, candidate.suffix
    if hint:
        stem = f"{stem}_{hint}"
        hinted = candidate.with_name(f"{stem}{suffix}")
        if not hinted.exists():
            return hinted
    counter = 1
    while True:
        alternative = candidate.with_name(f"{stem}_{counter}{suffix}")
        if not alternative.exists():
            return alternative
        counter += 1


def _copy_tree_unrecorded(entry, location: str, destination: Path) -> None:
    """Copy *entry* (a file, or a directory recursively) to *destination*.

    Used for the contents of a matched directory. Only that directory itself is
    recorded in the extraction map, so nothing is appended here. Failures are
    logged per entry instead of being raised.
    """
    try:
        if entry.IsFile():
            copy_file_entry_to_local(entry, destination)
        elif entry.IsDirectory():
            separator = _location_separator(location)
            for sub_entry in entry.sub_file_entries:
                _copy_tree_unrecorded(
                    sub_entry,
                    f"{location.rstrip(separator)}{separator}{sub_entry.name}",
                    destination / sub_entry.name,
                )
    except Exception as e:
        logger.warning(f"Failed to copy '{location}': {e}", exc_info=True)


def _copy_matched_entry(entry, location: str, output_dir: Path, extraction_map: list, base_output: Path) -> None:
    """Copy one search hit into *output_dir* and record it in *extraction_map*.

    Files are recorded as (location, local path relative to *base_output*).
    Directories are recorded once, with a trailing separator on both columns.
    Their contents are copied recursively without individual map entries.
    """
    try:
        separator = _location_separator(location)
        if entry.IsFile():
            output_dir.mkdir(parents=True, exist_ok=True)
            # Disambiguate same-named files with the name of their parent directory.
            parts = location.split(separator)
            hint = parts[-2] if len(parts) >= 2 else ""
            local_path = _unique_local_path(output_dir / entry.name, hint)
            if copy_file_entry_to_local(entry, local_path):
                extraction_map.append(
                    (str(location), str(local_path.relative_to(base_output)))
                )
        elif entry.IsDirectory():
            local_path = output_dir / entry.name
            if local_path.exists():
                logger.info(f"Directory '{local_path}' already exists, modifying name")
                local_path = _unique_local_path(local_path)
            local_path.mkdir(parents=True, exist_ok=True)
            logger.info(f"Created directory '{local_path}'")
            extraction_map.append(
                (str(location) + separator, str(local_path.relative_to(base_output)) + separator)
            )
            # Children go under the directory actually created (not a path re-joined
            # onto output_dir), so relative output directories are not doubled.
            for sub_entry in entry.sub_file_entries:
                _copy_tree_unrecorded(
                    sub_entry,
                    f"{location.rstrip(separator)}{separator}{sub_entry.name}",
                    local_path / sub_entry.name,
                )
    except Exception as e:
        logger.warning(f"Failed to copy '{location}': {e}", exc_info=True)


def process_volume(volume_path_spec, output_dir: Path, extraction_map: list, base_output: Path):
    """Extract artifacts from a single partition and save them locally.

    Args:
        volume_path_spec: dfvfs path spec of the volume's file system.
        output_dir: local directory that receives the copied artifacts.
        extraction_map: list that collects (original location, local path relative
            to *base_output*) tuples; appended to in place.
        base_output: root that recorded local paths are made relative to.

    Side effects:
        Sets the module-level ``os_type`` to the label returned by
        detect_os_and_artifacts(). Errors while processing the volume are logged,
        not raised.
    """
    global os_type
    logger.debug(f"Processing volume with type '{volume_path_spec.type_indicator}'")
    os_type, artifact_paths = detect_os_and_artifacts(volume_path_spec.type_indicator)
    if not artifact_paths:
        logger.warning("No artifact paths found, skipping volume")
        return

    output_dir = Path(output_dir)
    base_output = Path(base_output)
    try:
        file_system = resolver.Resolver.OpenFileSystem(volume_path_spec)
        searcher = FileSystemSearcher(file_system, volume_path_spec)

        for artifact_glob in artifact_paths:
            logger.debug(f"Searching for artifacts with pattern '{artifact_glob}'")
            find_spec = FindSpec(location_glob=artifact_glob)
            for match in searcher.Find(find_specs=[find_spec]):
                file_entry = file_system.GetFileEntryByPathSpec(match)
                # Check for None before touching the entry: an unopenable match
                # must be skipped, not abort the whole volume.
                if file_entry is None:
                    continue
                logger.debug(f"Matched: {file_entry.path_spec.location}")
                _copy_matched_entry(file_entry, match.location, output_dir, extraction_map, base_output)

        logger.info(f"Extraction for {os_type} artifacts completed in '{output_dir}'")

    except Exception as e:
        logger.error(f"Cannot process volume: {e}", exc_info=True)


def normalize_path_slashes(path: str) -> str:
    """
    Make every separator in *path* match the separator the path starts with.

    A leading backslash turns every '/' into '\\'; a leading '/' turns every
    '\\' into '/'. Empty paths, and paths starting with any other character,
    are returned untouched.
    """
    if not path:
        return path
    leading = path[0]
    if leading == "\\":
        return path.replace("/", "\\")
    if leading == "/":
        return path.replace("\\", "/")
    return path

def save_extraction_map(extraction_map: list, output_dir: Path):
    """
    Save the artifact extraction map to CSV.

    Every (original_path, local_relative_path) pair in *extraction_map* is written
    as its own row (no folder grouping), after normalize_path_slashes() is applied
    to both columns. Rows go through the csv module, so values containing commas,
    quotes or newlines are quoted and read back intact by csv.DictReader. Errors
    are logged, not raised.

    Args:
        extraction_map: iterable of (original_path, local_relative_path) pairs.
        output_dir: directory in which EXTRACTION_MAP_FILENAME is (over)written.
    """
    csv_file = Path(output_dir) / EXTRACTION_MAP_FILENAME
    try:
        logger.debug(f"Saving extraction map to '{csv_file}'")

        # newline="" as required by the csv module; lineterminator="\n" keeps the
        # same line endings the file has always used.
        with open(csv_file, "w", encoding="utf-8", newline="") as f:
            writer = csv.writer(f, lineterminator="\n")
            writer.writerow(["original_path", "local_relative_path"])
            for original, rel in extraction_map:
                writer.writerow([
                    normalize_path_slashes(str(original)),
                    normalize_path_slashes(str(rel)),
                ])

        logger.info(f"Extraction map saved to '{csv_file}'")
    except Exception as e:
        logger.error(f"Error saving extraction map: {e}", exc_info=True)


def run_gemini_ai(disk_image: str, extraction_csv: str, outputdir: Path, local: bool = False):
    """Run the Gemini AI agent on every artifact listed in the extraction map.

    For each CSV row, the agent script GEMINI_AGENT is launched once. Rows whose
    final path component is a registry hive name (NTUSER.DAT, SOFTWARE, SYSTEM)
    are skipped. Directory rows (path ending in '/' or '\\') use ``-f`` and a
    ``.md`` report; file rows use ``-e`` and a ``.json`` report. Reports go in
    *outputdir*. Per-row agent failures are logged and the loop continues.

    Args:
        disk_image: path of the disk image, passed through to the agent.
        extraction_csv: CSV produced by save_extraction_map().
        outputdir: directory for the reports (created if missing).
        local: append ``--local`` to every agent invocation.
    """
    logger.info(f"Running {GEMINI_AGENT} for extracted artifacts")
    outputdir = Path(outputdir)
    os.makedirs(outputdir, exist_ok=True)
    registry_hive_names = {"NTUSER.DAT", "SOFTWARE", "SYSTEM"}

    try:
        with open(extraction_csv, "r", encoding="utf-8", newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                original_path = row["original_path"]

                # Skip registry hives. Compare the last path component only, so that
                # paths such as '/etc/filesystem' are not mistaken for a SYSTEM hive.
                basename = original_path.replace("\\", "/").rsplit("/", 1)[-1]
                if basename.upper() in registry_hive_names:
                    continue

                # --- Build safe filename for report ---
                safe_name = original_path.strip("/\\").replace("/", "_").replace("\\", "_")
                if not safe_name:
                    safe_name = "root"

                # --- Build command ---
                if original_path.endswith(("/", "\\")):
                    report_file = outputdir / f"{safe_name}.md"
                    cmd = ["python3", GEMINI_AGENT, "-i", disk_image, "-f", original_path, "-o", str(report_file)]
                else:
                    report_file = outputdir / f"{safe_name}.json"
                    cmd = ["python3", GEMINI_AGENT, "-i", disk_image, "-e", original_path, "-o", str(report_file)]

                if local:
                    cmd.append("--local")

                logger.debug(f"Executing: {' '.join(cmd)}")
                try:
                    subprocess.run(cmd, check=True)
                    logger.info(f"Report file (if suspicious): {report_file}")
                except subprocess.CalledProcessError as e:
                    logger.error(f"{GEMINI_AGENT} failed for {original_path}: {e}")
                except OSError as e:
                    # The interpreter or agent could not be launched. Log it for this
                    # row instead of aborting the whole map.
                    logger.error(f"Could not launch {GEMINI_AGENT} for {original_path}: {e}")
    except Exception as e:
        logger.error(f"Failed to read extraction map {extraction_csv}: {e}", exc_info=True)


def extract_artifacts(disk_image: str, case_folder: str, output_base_dir: Path, local: bool = False):
    """Main artifact extraction workflow.

    Steps:
    1. Scan *disk_image* for volumes.
    2. Copy the artifacts of every volume into *output_base_dir*.
    3. Write the extraction map CSV.
    4. Copy the top-level ``.txt`` files (e.g. registry key dumps) into the
       ``00 - agent_reports`` folder.
    5. Run the Gemini agent on the mapped artifacts.

    Args:
        disk_image: path of the disk image to scan.
        case_folder: not used by this function; kept for interface compatibility.
        output_base_dir: root directory for extracted artifacts (str or Path).
        local: forwarded to run_gemini_ai() to run the agent in local mode.
    """
    logger.info(f"Starting extraction for image '{disk_image}'")
    # Accept str as well as Path: the `/` joins below require a Path.
    output_base_dir = Path(output_base_dir)

    scanner = VolumeScanner(mediator=CustomMediator())
    options = VolumeScannerOptions()
    options.partitions = ["all"]
    base_path_specs = scanner.GetBasePathSpecs(disk_image, options)

    if not base_path_specs:
        logger.warning("No useful file system found in the image.")
        return

    os.makedirs(output_base_dir, exist_ok=True)
    extraction_map = []

    for volume_path_spec in base_path_specs:
        process_volume(volume_path_spec, output_base_dir, extraction_map, output_base_dir)

    save_extraction_map(extraction_map, output_base_dir)
    extraction_csv = output_base_dir / EXTRACTION_MAP_FILENAME

    # Copy the top-level .txt files from output_base_dir into the AI reports folder.
    ai_reports_dir = output_base_dir / "00 - agent_reports"
    ai_reports_dir.mkdir(parents=True, exist_ok=True)
    for txt_file in output_base_dir.glob("*.txt"):
        shutil.copy(str(txt_file), ai_reports_dir / txt_file.name)

    run_gemini_ai(disk_image, extraction_csv, ai_reports_dir, local=local)

    logger.info("Artifact extraction + Gemini AI analysis completed")

# =======================
# MAIN
# =======================
def main():
    """Parse command-line arguments and run the extraction + Gemini analysis pipeline."""
    parser = argparse.ArgumentParser(description="Extract artifacts from disk image and run Gemini AI analysis.")
    parser.add_argument("disk_image", help="Path to the disk image file")
    parser.add_argument("case_folder", help="Case folder for output")
    parser.add_argument("--local", action="store_true", help="Run Gemini agent in local mode")

    args = parser.parse_args()

    output_base_dir = Path(args.case_folder) / OUTPUT_SUBDIR_DISK
    os.makedirs(output_base_dir, exist_ok=True)

    # Take the flag from argparse so --local is honoured in any argument position,
    # not only as the third positional token.
    extract_artifacts(args.disk_image, args.case_folder, output_base_dir, local=args.local)


if __name__ == "__main__":
    main()