import os
import hashlib
import logging
import shutil
from pathlib import Path
import argparse
from datetime import datetime
import subprocess

### ── GLOBAL CONFIGURATION ──
from config import (
    BASE_DIR,
    CASE_DIRS,
    #REPORT_ROLES,
    #ROLE_TEMPLATES,
    CLEAN_PLUGIN_EXAMPLES_TEMPLATES,
    ARTIFACT_EXTENSIONS,
    CREATECASE_LOG_FILE,
    CREATECASE_LOG_FORMAT,
    CREATECASE_LOG_LEVEL,
    PROCESS_ACQUISITION_SCRIPT_PATH
)


### ── UTILITY FUNCTIONS ──

def setup_logging():
    """
    Configures the logging system to output messages both to a file and the console.

    The log file is opened as UTF-8 because the script's log messages contain
    emoji; with the platform default encoding (e.g. cp1252 on Windows) those
    records would fail to be written to the file.

    Arguments:
        None

    Returns:
        None
    """
    logging.basicConfig(level=CREATECASE_LOG_LEVEL,
                        format=CREATECASE_LOG_FORMAT,
                        handlers=[logging.FileHandler(CREATECASE_LOG_FILE, encoding="utf-8"),
                                  logging.StreamHandler()])


def sha256sum(filepath: Path, chunk_size: int = 1024 * 1024) -> str:
    """
    Computes the SHA-256 hash of a file by streaming it in fixed-size chunks.

    Arguments:
        filepath (Path): Path to the file to hash (a str path also works).
        chunk_size (int): Bytes read per iteration. The 1 MiB default keeps
            memory bounded while avoiding the per-read overhead of tiny chunks
            on multi-gigabyte disk and memory images.

    Returns:
        str: The SHA-256 hash of the file as a hexadecimal string.

    Raises:
        ValueError: If chunk_size is not a positive integer.
    """
    # read(0) would return b"" immediately and silently yield the empty-file hash.
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    h = hashlib.sha256()
    with open(filepath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()


def run_command(cmd: list, error_msg: str) -> bool:
    """
    Executes an external command (argument list, no shell) and logs an error if it fails.

    Arguments:
        cmd (list): Program followed by its command-line arguments.
        error_msg (str): Prefix for the error message logged on failure.

    Returns:
        bool: True if the command exited with status 0; False if it exited
        with a non-zero status or could not be started at all (for example,
        the executable is not installed / not on PATH).
    """
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers FileNotFoundError/PermissionError when the program
        # itself cannot be launched; report it like any other failure
        # instead of crashing the caller.
        logging.error(f"{error_msg}: {e}")
        return False
    return True


### ── SIGNING AND VERIFICATION FUNCTIONS ──

def sign_file(file_path: Path, key_path: Path):
    """
    Creates a detached SHA-256 signature for a file with ``openssl dgst -sign``.

    The signature is written next to the file, with ".sig" appended to the
    file's existing suffix (e.g. ``manifest.txt`` -> ``manifest.txt.sig``).

    Arguments:
        file_path (Path): File whose contents are signed.
        key_path (Path): Private key handed to OpenSSL.

    Returns:
        None. Success is logged at INFO level; failures are logged by
        ``run_command``.
    """
    signature_path = file_path.with_suffix(file_path.suffix + ".sig")
    openssl_cmd = [
        "openssl", "dgst", "-sha256",
        "-sign", str(key_path),
        "-out", str(signature_path),
        str(file_path),
    ]
    if not run_command(openssl_cmd, f"❌ Failed to sign {file_path.name}"):
        return
    logging.info(f"🔏 Signed {file_path.name}")


def verify_signature(file_path: Path, sig_path: Path, pubkey_path: Path) -> bool:
    """
    Checks a detached SHA-256 signature with ``openssl dgst -verify``.

    Arguments:
        file_path (Path): File whose contents were signed.
        sig_path (Path): Detached signature (.sig) produced for that file.
        pubkey_path (Path): Public key matching the signing key.

    Returns:
        bool: True when OpenSSL accepts the signature, False otherwise
        (the failure itself is logged by ``run_command``).
    """
    openssl_cmd = [
        "openssl", "dgst", "-sha256",
        "-verify", str(pubkey_path),
        "-signature", str(sig_path),
        str(file_path),
    ]
    failure_msg = f"❌ Signature verification failed for {file_path.name}"
    if not run_command(openssl_cmd, failure_msg):
        return False
    logging.info(f"✅ Verified signature for {file_path.name}")
    return True


### ── DIRECTORY AND TEMPLATE CREATION FUNCTIONS ──

def generate_role_templates(report_dir: Path, templates=None):
    """
    Creates template files for each role in the report directory.

    Files that already exist are never overwritten, so re-running does not
    clobber reports an analyst has already started editing.

    Arguments:
        report_dir (Path): Base directory where role templates will be created.
        templates (dict, optional): Mapping ``{role: {relative_path: content}}``.
            When omitted, the module-level ROLE_TEMPLATES is used if defined.

    Returns:
        None

    Raises:
        RuntimeError: If *templates* is omitted and ROLE_TEMPLATES is not
            available in this module (e.g. not imported from config).
    """
    if templates is None:
        try:
            templates = ROLE_TEMPLATES
        except NameError:
            raise RuntimeError(
                "No role templates available: pass `templates` explicitly "
                "or import ROLE_TEMPLATES from config."
            ) from None

    for role, files in templates.items():
        for rel_path, content in files.items():
            file_path = report_dir / role / rel_path
            file_path.parent.mkdir(parents=True, exist_ok=True)
            if not file_path.exists():
                file_path.write_text(content, encoding="utf-8")


def generate_clean_examples(case_dir: Path, os_type: str):
    """
    Writes the reference ("clean") plugin outputs for an OS into the case.

    Each entry of ``CLEAN_PLUGIN_EXAMPLES_TEMPLATES[os_type]`` must provide a
    "pluginname" and an "out" value; the stripped output is written, with a
    trailing newline, to
    ``03_processing/memory/clean_examples/clean_<pluginname>.txt``.
    Existing files are overwritten. An unknown os_type produces no files
    (the directory is still created).

    Arguments:
        case_dir (Path): Root directory of the case.
        os_type (str): Key into CLEAN_PLUGIN_EXAMPLES_TEMPLATES.

    Returns:
        None
    """
    examples_dir = case_dir / "03_processing" / "memory" / "clean_examples"
    examples_dir.mkdir(parents=True, exist_ok=True)

    for entry in CLEAN_PLUGIN_EXAMPLES_TEMPLATES.get(os_type, []):
        name = entry["pluginname"]
        text = entry["out"]
        destination = examples_dir / f"clean_{name}.txt"
        destination.write_text(text.strip() + "\n", encoding="utf-8")

def create_case_dirs(root: Path):
    """
    Builds the standard case directory skeleton under *root*.

    Every entry of CASE_DIRS is interpreted relative to *root* and created
    together with any missing parents; directories that already exist are
    left untouched.

    Arguments:
        root (Path): Case root directory.

    Returns:
        None
    """
    for target in (root / relative for relative in CASE_DIRS):
        target.mkdir(parents=True, exist_ok=True)


### ── FILE COPY, HASHING, AND METADATA ──

def _classify_artifact(filename: str):
    """Return the canonical base name for *filename* per ARTIFACT_EXTENSIONS, or None if unrecognized."""
    lowered = filename.lower()
    # Checked in this order, so a name matching several categories resolves
    # to the first one (disk, then memory, then network).
    for category, target_name in (("disk", "hdd_image_flat"),
                                  ("memory", "memdump"),
                                  ("network", "net_traffic")):
        if any(key in lowered for key in ARTIFACT_EXTENSIONS[category]):
            return target_name
    return None


def _unique_artifact_name(target_name: str, suffix: str, used: set) -> str:
    """Return '<target_name><suffix>', appending _1, _2, ... if already taken in *used*; records the choice."""
    candidate = f"{target_name}{suffix}"
    counter = 1
    # Compare case-insensitively so e.g. memdump.RAW / memdump.raw cannot
    # collide on case-insensitive filesystems.
    while candidate.lower() in used:
        candidate = f"{target_name}_{counter}{suffix}"
        counter += 1
    used.add(candidate.lower())
    return candidate


def copy_and_hash_artifacts(src_dir: Path, dst_dir: Path, skip_copy: bool = False, no_hash: bool = False) -> dict:
    """
    Copies forensic artifacts from source to destination (unless skip_copy is True),
    categorizes them, and computes their SHA-256 hashes (unless no_hash is True).

    Recognized files (per ARTIFACT_EXTENSIONS) are renamed to a canonical base
    name (hdd_image_flat / memdump / net_traffic) keeping their suffix. If
    several files map to the same name, later ones (in sorted source order)
    get a numeric suffix (e.g. memdump_1.raw) instead of overwriting earlier
    ones. Hidden files, non-files and unrecognized files are skipped.

    Arguments:
        src_dir (Path): Directory containing the original artifacts.
        dst_dir (Path): Directory where categorized artifacts will be copied.
        skip_copy (bool): If True, do not copy files, only calculate hashes from source.
        no_hash (bool): If True, skip hashing. Files are still copied when
            skip_copy is False.

    Returns:
        dict: Mapping of filenames (copied names, or source names when
        skip_copy is True) to their SHA-256 hashes; empty when no_hash is True.
    """
    logging.info(f"🔍 Scanning artifact directory: {src_dir}")

    if not skip_copy:
        dst_dir.mkdir(parents=True, exist_ok=True)
        logging.debug(f"Created or confirmed destination directory: {dst_dir}")

    if no_hash:
        logging.info("Skipping hashing of artifacts as per user request.")
        if skip_copy:
            # Nothing to copy and nothing to hash.
            return {}

    hashes = {}
    used_names = set()
    # Sorted so that collision renaming is deterministic across runs.
    for file in sorted(src_dir.iterdir()):
        if not file.is_file() or file.name.startswith("."):
            continue

        target_name = _classify_artifact(file.name)
        if target_name is None:
            logging.warning(f"⚠️ Unrecognized artifact type, skipping: {file.name}")
            continue

        if skip_copy:
            h = sha256sum(file)
            hashes[file.name] = h
            logging.info(f"Calculated hash for {file.name} (hash={h}) [skip_copy=True]")
            continue

        target = dst_dir / _unique_artifact_name(target_name, file.suffix, used_names)
        if target.stem != target_name:
            logging.warning(
                f"⚠️ Another artifact already maps to '{target_name}{file.suffix}'; "
                f"storing '{file.name}' as '{target.name}'"
            )
        logging.info(f"📄 Copying '{file.name}' to '{target.name}' as type '{target_name}'")
        shutil.copy2(file, target)

        if no_hash:
            logging.info(f"Copied {file.name} -> {target.name} (hashing skipped)")
            continue

        h = sha256sum(target)
        hashes[target.name] = h
        logging.info(f"Copied {file.name} -> {target.name} (hash={h})")

    return hashes


def write_metadata(root: Path, case_id: str, os_type: str,
                   hashes: dict, zip_file=None, signing_key=None):
    """
    Writes case metadata including hashes, case information, and (optionally) a signed manifest.

    Produces:
      * ``00_meta/case_info.txt`` with the case ID, OS, and a timezone-aware
        local creation timestamp.
      * ``00_meta/malware_sample.zip`` plus ``02_raw_checksums/malware_sample_hash.sha256``
        (bare hex digest of the source archive) when *zip_file* is an existing file.
      * ``02_raw_checksums/manifest.txt`` with one ``<hash>  <filename>`` line
        per entry of *hashes* (the layout used by ``sha256sum``), sorted;
        empty when *hashes* is empty.
      * ``manifest.txt.sig`` next to the manifest when *signing_key* is given.

    Arguments:
        root (Path): Root case directory.
        case_id (str): Unique identifier for the case.
        os_type (str): Operating system type (e.g., windows, linux).
        hashes (dict): Dictionary of file hashes to include in the manifest.
        zip_file (str, optional): Optional malware sample archive to include.
        signing_key (str, optional): Private key path used to sign the manifest.

    Returns:
        None
    """
    meta = root / "00_meta"
    checksums = root / "02_raw_checksums"
    for path in (meta, checksums):
        path.mkdir(parents=True, exist_ok=True)

    logging.info(f"📝 Writing metadata to: {meta}")
    # Record the UTC offset so the timestamp is unambiguous across hosts.
    created = datetime.now().astimezone()
    (meta / "case_info.txt").write_text(
        f"Case ID: {case_id}\nOS: {os_type}\nDate: {created}\n",
        encoding="utf-8",
    )

    if zip_file:
        z = Path(zip_file)
        if z.is_file():
            logging.info(f"Copying malware sample to: {meta}")
            shutil.copy2(z, meta / "malware_sample.zip")
            h = sha256sum(z)
            (checksums / "malware_sample_hash.sha256").write_text(h)
        else:
            # Previously this was silently ignored; make the omission visible.
            logging.warning(f"⚠️ Malware sample not found or not a file, skipping: {z}")

    logging.info(f"📝 creating manifest in: {checksums}")
    lines = sorted(f"{h}  {fname}" for fname, h in hashes.items())
    manifest = checksums / "manifest.txt"
    manifest.write_text("".join(f"{line}\n" for line in lines), encoding="utf-8")
    if signing_key:
        sign_file(manifest, Path(signing_key))


### ── Main Flow ──

def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the case-creation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--signing-key", default=None, help="Path to private key for signing manifest")
    parser.add_argument("--verify-key", default=None, help="Path to public key for verifying manifest")
    parser.add_argument("-y", "--yes", action="store_true", help="Automatically proceed with processing without prompting")
    parser.add_argument("case_id", help="Unique identifier for the case")
    parser.add_argument("hash_id", help="Name or hash for the malware sample")
    parser.add_argument("artifact_dir", help="Directory containing artifacts")
    parser.add_argument("os_type", choices=["windows", "linux"], help="Operating system type")
    parser.add_argument("zip_file", nargs="?", default=None, help="Optional zip file containing malware samples")
    parser.add_argument("--skip_copy", action="store_true", help="Skip copying artifacts")
    parser.add_argument("--local", action="store_true", help="Run AI locally without internet connection")
    parser.add_argument("--no_hash", action="store_true", help="Skip hashing artifacts")
    return parser


def _move_and_verify(tmp: Path, acq: Path, hashes: dict, verify: bool):
    """
    Moves staged artifacts from *tmp* into *acq*, removes *tmp*, and optionally re-checks hashes.

    Raises:
        ValueError: If hashes recomputed at the destination differ from *hashes*.
    """
    logging.info("📦 Moving artifacts to final acquisition directory")
    acq.mkdir(parents=True, exist_ok=True)
    moved = []
    for f in tmp.iterdir():
        dest = acq / f.name
        shutil.move(str(f), dest)
        moved.append(dest)
    logging.info(f"🧹 Cleaning up temporary directory: {tmp}")
    shutil.rmtree(tmp)

    if not verify:
        return

    logging.info("🔍 Recomputing hashes at final destination")
    # Only re-hash what this run moved: the acquisition directory may already
    # contain files from an earlier run or sub-directories, which previously
    # caused spurious mismatches (or an attempt to hash a directory).
    recomputed = {p.name: sha256sum(p) for p in moved if p.is_file()}

    if hashes != recomputed:
        logging.error("❌ Mismatch between original and recomputed hashes! Details:")
        for fname in sorted(set(hashes) | set(recomputed)):
            orig = hashes.get(fname)
            new = recomputed.get(fname)
            if orig != new:
                logging.error(f"  {fname}: original={orig}, recomputed={new}")
        raise ValueError("Hash mismatch between original and recomputed files.")

    logging.info("✅ Original and recomputed hashes match for all files.")
    logging.debug(f"Computed hashes: {recomputed}")


def _run_processing(final: Path, os_type: str, artifact_dir: Path, local: bool):
    """Runs the acquisition-processing script on the new case; a failure is logged by run_command."""
    cmd = ["python3", str(PROCESS_ACQUISITION_SCRIPT_PATH), str(final),
           "--os", os_type, "--artefacts", str(artifact_dir)]
    if local:
        cmd.append("--local")
    run_command(cmd, "❌ Processing script failed")


def _confirm(prompt: str) -> bool:
    """Ask a yes/no question on stdin; anything but y/yes (or a closed stdin) means no."""
    try:
        answer = input(prompt)
    except EOFError:
        # Non-interactive run without --yes: do not crash, just decline.
        return False
    return answer.strip().lower() in ("y", "yes")


def main():
    """
    Command-line entry point: builds a forensic case from an artifact directory.

    Steps:
      1. Copy (or, with --skip_copy, only hash) recognized artifacts.
      2. Create the case skeleton and clean plugin examples.
      3. Move staged copies into 01_acquisition and re-verify hashes.
      4. Write metadata/manifest, optionally sign and verify it.
      5. Optionally run the processing script.

    Exits with status 1 if the artifact directory is missing. Raises
    ValueError if the post-move integrity check fails.
    """
    setup_logging()

    logging.info("🚀 Starting forensic case creation script.")

    args = _build_arg_parser().parse_args()
    logging.debug(f"Parsed arguments: {args}")

    artifact_dir = Path(args.artifact_dir)
    if not artifact_dir.is_dir():
        logging.error(f"{artifact_dir} does not exist or is not a directory")
        # Non-zero exit so wrappers/pipelines can detect the failure.
        raise SystemExit(1)

    tmp = None
    if args.skip_copy:
        logging.info("Skipping artifact copy as per user request.")
        hashes = copy_and_hash_artifacts(artifact_dir, artifact_dir, skip_copy=True, no_hash=args.no_hash)
    else:
        tmp = BASE_DIR / args.os_type / f"{args.case_id}_tmp"
        hashes = copy_and_hash_artifacts(artifact_dir, tmp, skip_copy=False, no_hash=args.no_hash)

    final = BASE_DIR / args.os_type / f"{args.case_id}_{args.hash_id}"
    create_case_dirs(final)
    # NOTE: role templates are disabled (ROLE_TEMPLATES is not imported from config).
    generate_clean_examples(final, args.os_type)

    if tmp is not None:
        _move_and_verify(tmp, final / "01_acquisition", hashes, verify=not args.no_hash)

    logging.info("📝 Writing metadata for the case")
    write_metadata(final, args.case_id, args.os_type, hashes,
                   zip_file=args.zip_file, signing_key=args.signing_key)

    logging.info(f"✅ Case successfully created at: {final}")

    # Verify the manifest signature if requested.
    if args.verify_key:
        manifest = final / "02_raw_checksums" / "manifest.txt"
        sig = manifest.with_suffix(".txt.sig")
        if sig.exists():
            verify_signature(manifest, sig, Path(args.verify_key))
        else:
            logging.warning(f"⚠️ --verify-key given but no signature found at {sig}; skipping verification")

    if args.yes:
        logging.info(f"🛠️ Automatically executing processing script on {str(final)}")
        _run_processing(final, args.os_type, artifact_dir, args.local)
    elif _confirm("Proceed with processing? [y/N]: "):
        _run_processing(final, args.os_type, artifact_dir, args.local)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
