#!/usr/bin/env python3
"""
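Minimal dfVFS-based root partition finder.

Scans a storage media source, detects the Linux root file system (falling back
to a Windows system volume), and can optionally print the beginning of a file
located within the detected root.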
"""

import itertools
import logging
import os
import sys

from dfvfs.helpers import volume_scanner
from plaso.cli.storage_media_tool import StorageMediaToolVolumeScanner, StorageMediaToolVolumeScannerOptions
from dfvfs.resolver import resolver
from dfvfs.path import factory as path_spec_factory
from dfvfs.path.ext_path_spec import EXTPathSpec
from dfvfs.path.ntfs_path_spec import NTFSPathSpec
from dfvfs.lib import definitions
import pytsk3

# Add the directory four levels above this file's directory to sys.path so
# the project's ``config`` module can be imported below. It is appended, so
# earlier sys.path entries still take precedence. This must run before the
# ``from config import ...`` statement.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../")))

# Project settings used by this tool:
#   CATFILE_LOG_LEVEL - level passed to logging.basicConfig() in main()
#   MAX_OUTPUT        - character budget for displayed file content
#   CHUNK_SIZE        - number of bytes read from the requested file
from config import (
    CATFILE_LOG_LEVEL,
    MAX_OUTPUT,
    CHUNK_SIZE
)

# Local alias for the configured log level.
LOG_LEVEL = CATFILE_LOG_LEVEL

# -------------------------------
# MEDIATOR
# -------------------------------
class MyMediator(volume_scanner.VolumeScannerMediator):
    """Non-interactive volume scanner mediator.

    Every ``Get*Identifiers`` callback returns the identifiers it is offered
    unchanged. Nothing is filtered out and the user is never prompted.
    """

    @staticmethod
    def _select_all(volume_identifiers):
        """Return *volume_identifiers* unchanged (select everything offered)."""
        return volume_identifiers

    def GetPartitionIdentifiers(self, volume_system, volume_identifiers):
        """Select every partition identifier."""
        return self._select_all(volume_identifiers)

    def GetAPFSVolumeIdentifiers(self, volume_system, volume_identifiers):
        """Select every APFS volume identifier."""
        return self._select_all(volume_identifiers)

    def GetLVMVolumeIdentifiers(self, volume_system, volume_identifiers):
        """Select every LVM volume identifier."""
        return self._select_all(volume_identifiers)

    def GetVSSStoreIdentifiers(self, volume_system, volume_identifiers):
        """Select every VSS store identifier."""
        return self._select_all(volume_identifiers)

    def UnlockEncryptedVolume(self, source_scanner_object, scan_context, locked_scan_node, credentials):
        """Report success without applying any credentials.

        NOTE(review): nothing is actually unlocked here. Confirm how the
        scanner treats a True result for a volume that is still locked.
        """
        unlocked = True
        return unlocked

# -------------------------------
# ROOT DETECTION
# -------------------------------
def is_linux_root(scan_node):
    """Heuristically decide whether *scan_node* holds a Linux root file system.

    The node's file system is opened through the dfvfs resolver. The node
    qualifies when its root directory contains all of ``usr``, ``var``,
    ``etc`` and ``bin``.

    This is a best-effort probe. Any failure (an unreadable or unsupported
    file system, a node without a path spec, ...) is logged at DEBUG level
    and yields False instead of raising.

    Args:
        scan_node: dfvfs source scan node whose ``path_spec`` addresses a
            file system.

    Returns:
        bool: True if the root directory contains every marker directory.
    """
    required = {'usr', 'var', 'etc', 'bin'}
    try:
        fs = resolver.Resolver.OpenFileSystem(scan_node.path_spec)
        # GetRootFileEntry() is the dfvfs FileSystem API for the root
        # directory. It builds a correctly parented path spec itself.
        root_dir = fs.GetRootFileEntry()
        if not root_dir:
            return False
        entries = {entry.name for entry in root_dir.sub_file_entries}
    except Exception as exc:  # deliberate best-effort probe
        logging.debug("Linux root probe failed: %s", exc)
        return False
    return required.issubset(entries)

def find_linux_root(scan_node):
    """Return the first Linux root node found in a depth-first, pre-order walk.

    Only nodes whose type indicator is ``TSK`` or ``EXT`` are probed with
    is_linux_root(). Returns None when no node in the tree qualifies.
    """
    def _preorder(node):
        # Lazily yield the node itself, then each child's subtree in order.
        yield node
        for sub_node in node.sub_nodes:
            yield from _preorder(sub_node)

    for candidate in _preorder(scan_node):
        if candidate.type_indicator in ('TSK', 'EXT') and is_linux_root(candidate):
            return candidate
    return None

def is_windows_root(scan_node):
    """Heuristically decide whether *scan_node* holds a Windows system volume.

    The node's file system is opened through the dfvfs resolver. Only NTFS
    and FAT file systems are considered. The file entry addressed by the
    node's own path spec (expected to be the volume root) must contain at
    least one of ``Windows``, ``Program Files`` or ``Users``, compared
    case-insensitively. Any failure while probing yields False.
    """
    markers = ('windows', 'program files', 'users')
    try:
        file_system = resolver.Resolver.OpenFileSystem(scan_node.path_spec)
        if file_system.type_indicator not in ('NTFS', 'FAT'):
            return False
        top_entry = file_system.GetFileEntryByPathSpec(scan_node.path_spec)
        if not top_entry:
            return False
        names = {child.name.lower() for child in top_entry.sub_file_entries}
        return any(marker in names for marker in markers)
    except Exception:
        return False

def find_windows_root(scan_node):
    """Return the first Windows root node found in a depth-first, pre-order walk.

    Only nodes whose type indicator is ``NTFS`` or ``FAT`` are probed with
    is_windows_root(). Returns None when no node in the tree qualifies.
    """
    pending = [scan_node]
    while pending:
        node = pending.pop()
        if node.type_indicator in ('NTFS', 'FAT') and is_windows_root(node):
            return node
        # Push children reversed so they are popped (visited) left to right.
        pending.extend(reversed(list(node.sub_nodes)))
    return None

# -------------------------------
# FILE DISPLAY
# -------------------------------


def _render_run(token, run_length):
    """Render *run_length* repeats of *token*, collapsing long runs.

    Runs longer than three are written as ``token*N*``. Shorter runs are
    written out literally.
    """
    if run_length > 3:
        return f"{token}*{run_length}*"
    return token * run_length


def _render_binary(data, limit):
    """Render *data* byte-wise, run-length collapsed, within *limit* chars.

    Printable ASCII bytes (0x20-0x7e) appear as themselves and every other
    byte as a ``\\xNN`` escape. If the rendering would exceed *limit*
    characters, it is cut at exactly *limit* characters and "..." is
    appended once.
    """
    tokens = (chr(b) if 32 <= b <= 126 else f"\\x{b:02x}" for b in data)
    pieces = []
    used = 0
    for token, run in itertools.groupby(tokens):
        chunk = _render_run(token, sum(1 for _ in run))
        if used + len(chunk) > limit:
            # used never exceeds limit, so the slice bound is non-negative.
            pieces.append(chunk[:limit - used] + "...")
            break
        pieces.append(chunk)
        used += len(chunk)
    return "".join(pieces)


def print_limited_content_stream(data_stream, max_output=None):
    """Print *data_stream* to stdout, limited to about *max_output* characters.

    Bytes that decode as UTF-8 are printed as text. The text is truncated to
    *max_output* characters plus "..." when longer. Anything else is shown
    byte-wise: printable ASCII as-is, other bytes as ``\\xNN``, and runs of
    more than three identical tokens collapsed to ``tok*N*``. That rendering
    is cut at *max_output* characters, with a single trailing "...".

    Args:
        data_stream: bytes-like object to display.
        max_output: character budget. None (the default) uses MAX_OUTPUT
            from the project config.
    """
    limit = MAX_OUTPUT if max_output is None else max_output
    try:
        text = data_stream.decode("utf-8")
    except UnicodeDecodeError:
        # Not valid UTF-8, e.g. binary data or a chunk cut mid-character.
        print(_render_binary(data_stream, limit))
        return
    if len(text) > limit:
        text = text[:limit] + "..."
    print(text)


def _show_tsk_file(root_scan_node, relative_path):
    """Print the start of *relative_path* by reading the root's parent volume with pytsk3."""

    class Img_Info(pytsk3.Img_Info):
        """pytsk3 image backed by a dfvfs file-like object."""

        def __init__(self, file_object):
            self._file_object = file_object
            # TSK_IMG_TYPE_EXTERNAL makes pytsk3 use the read()/get_size()
            # overrides below instead of opening ``url`` itself.
            super().__init__(url="", type=pytsk3.TSK_IMG_TYPE_EXTERNAL)

        def read(self, offset, size):
            self._file_object.seek(offset)
            return self._file_object.read(size)

        def get_size(self):
            current = self._file_object.tell()
            self._file_object.seek(0, os.SEEK_END)
            size = self._file_object.tell()
            self._file_object.seek(current)
            return size

    volume_file_object = resolver.Resolver.OpenFileObject(root_scan_node.path_spec.parent)
    tsk_fs = pytsk3.FS_Info(Img_Info(volume_file_object))
    relative_path_tsk = relative_path.replace("\\", "/")

    try:
        file_obj = tsk_fs.open(relative_path_tsk)
    except Exception as e:
        print(f"[ERROR] File not found: {relative_path_tsk} -> {e}")
        return

    meta = file_obj.info.meta
    if meta is None:
        print(f"[ERROR] No metadata for: {relative_path_tsk}")
        return

    # Only the first chunk is read; print_limited_content_stream applies MAX_OUTPUT.
    size = min(CHUNK_SIZE, meta.size)
    if size > 0:
        print_limited_content_stream(file_obj.read_random(0, size))


def _show_ext_file(fs, root_scan_node, relative_path):
    """Print the start of *relative_path* using the dfvfs EXT file system *fs*."""
    # A path spec for a file inside the EXT volume shares the EXT root's parent.
    path_spec = EXTPathSpec(location=relative_path, parent=root_scan_node.path_spec.parent)
    try:
        # NOTE(review): this appears to return the low-level pyfsext entry
        # (it is used for read() and symbolic_link_target below) — confirm.
        file_entry = fs.GetEXTFileEntryByPathSpec(path_spec)
    except Exception as e:
        print(f"[ERROR] File not found: {relative_path} -> {e}")
        return
    if not file_entry:
        print(f"[ERROR] File not found: {relative_path}")
        return

    # Symbolic links are reported rather than followed.
    link_target = file_entry.symbolic_link_target
    if link_target:
        print(f"[INFO] The file is a symlink: {relative_path} -> {link_target}")
        return

    print_limited_content_stream(file_entry.read(CHUNK_SIZE))


def show_file_contents_limited(root_scan_node, relative_path):
    """Print the beginning of *relative_path* from the file system at *root_scan_node*.

    EXT file systems are read through the dfvfs EXT back end. Every other
    root type the finders can return (NTFS, FAT, TSK) is read with pytsk3
    directly from the root's parent volume. At most CHUNK_SIZE bytes are
    read, and print_limited_content_stream() caps the displayed output.

    A missing file is reported on stdout with an "[ERROR]" prefix instead
    of raising.

    Args:
        root_scan_node: dfvfs scan node of the detected root file system.
        relative_path: path of the file within that file system. Backslashes
            are converted to "/" for the pytsk3 branch.
    """
    fs = resolver.Resolver.OpenFileSystem(root_scan_node.path_spec)
    if fs.type_indicator == definitions.TYPE_INDICATOR_EXT:
        _show_ext_file(fs, root_scan_node, relative_path)
    else:
        _show_tsk_file(root_scan_node, relative_path)


# -------------------------------
# MAIN EXECUTION
# -------------------------------
def main():
    """Command-line entry point.

    Usage: ``<script> <source_path> [relative_path]``

    Scans *source_path* with plaso's storage media volume scanner. It then
    locates a Linux root file system, falling back to a Windows system
    volume. If *relative_path* is given, the start of that file is printed
    from the detected root. Otherwise the detected root is reported.

    Exits with status 1 on bad usage, scan failure, or when no root is found.
    """
    if not 2 <= len(sys.argv) <= 3:
        print(f"Usage: {sys.argv[0]} <source_path> [relative_path]")
        sys.exit(1)

    logging.basicConfig(level=LOG_LEVEL)
    source_path = sys.argv[1]
    # The file argument is optional; without it only root detection runs.
    file_to_show = sys.argv[2] if len(sys.argv) == 3 else None

    mediator = MyMediator()
    options = StorageMediaToolVolumeScannerOptions()
    scanner = StorageMediaToolVolumeScanner(mediator=mediator)
    scan_context = scanner.ScanSource(source_path, options, base_path_specs=[])

    if not scan_context:
        print("[ERROR] Failed to scan source.")
        sys.exit(1)

    root_node = scan_context.GetRootScanNode()

    # Prefer a Linux root; fall back to a Windows system volume.
    root_kind = "Linux"
    main_partition_node = find_linux_root(root_node)
    if not main_partition_node:
        root_kind = "Windows"
        main_partition_node = find_windows_root(root_node)
        if not main_partition_node:
            print("[ERROR] No suitable root partition found.")
            sys.exit(1)

    if file_to_show:
        show_file_contents_limited(main_partition_node, file_to_show)
    else:
        # Nothing to display: report the detected root instead of exiting silently.
        print(f"[INFO] {root_kind} root partition found: {main_partition_node.path_spec}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()