import xml.etree.ElementTree as ET
import shutil
import os
import re
import xmlformatter
from utils import check_in_tag
from collections import defaultdict
from variables import PATH_TO_LOCAL_DECODER,PATH_TO_LOCAL_RULES

from remote import Remote

class Xml_handler(Remote):
    """
    A utility class for managing XML configurations. This class provides methods to add group and decoder sections
    to an XML file, ensuring that duplicates are handled appropriately with warnings.
    """

    def __init__(self,**kwargs):
        """
        Set up the handler on top of the Remote base class.

        Every keyword argument is forwarded unchanged to the parent constructor.
        When the resulting agent type is 'manager', the local rules file
        (PATH_TO_LOCAL_RULES) and the local decoder file (PATH_TO_LOCAL_DECODER)
        are validated and their paths stored as ``local_rules_xml_path`` and
        ``local_decoder_xml_path``. For any other agent type neither attribute
        is created.

        Raises:
            ValueError: Propagated from _validate_xml_file (e.g. a file does not exist).
        """
        super().__init__(**kwargs)

        # Local rule/decoder files are reserved to the manager.
        if self.agent_type != 'manager':
            return

        manager_files = (
            ('local_rules_xml_path', PATH_TO_LOCAL_RULES),
            ('local_decoder_xml_path', PATH_TO_LOCAL_DECODER),
        )
        for attribute_name, xml_path in manager_files:
            setattr(self, attribute_name, self._validate_xml_file(xml_path))
        

    def _validate_xml_file(self, file_path: str) -> str:
        """
        Validate that the provided file path is a valid XML file and contains a <root> tag.
        If the <root> tag is missing, enclose the file with a <root> tag gracefully.

        Args:
            file_path (str): Path to the XML file.

        Returns:
            str: The validated file path.

        Raises:
            ValueError: If the file is not a valid XML file.
        """
    
        if not os.path.isfile(file_path):
            raise ValueError(f"The file {file_path} does not exist.")

        try:
            tree = ET.parse(file_path)
            root = tree.getroot()
            if root.tag != 'root':
                self.logger.warning(f"The file {file_path} does not contain a <root> tag. Enclosing the file with a <root> tag.")
                new_root = ET.Element('root')
                new_root.append(root)
                tree = ET.ElementTree(new_root)
                tree.write(file_path, encoding='utf-8', xml_declaration=True)
        except ET.ParseError:
            self.logger.warning(f"The file {file_path} is not a valid XML file. Enclosing the content with a <root> tag.")
            with open(file_path, 'r') as file:
                content = file.read()

            # Check for any XML declaration
            xml_declaration_match = re.match(r'<\?xml.*\?>', content)
            if xml_declaration_match:
                xml_declaration = xml_declaration_match.group(0)
                new_content = f"{xml_declaration}<root>{content[len(xml_declaration):]}</root>"
            else:
                new_content = f"<root>{content}</root>"

            with open(file_path, 'w') as file:
                file.write(new_content)

        return file_path



    

    def load_xml_file(self, file_path: str) -> ET.Element:
        """
        Load an XML or .conf file (i.e., ossec.conf, local_rules.xml, local_decoder.xml) and return the root element.

        Args:
            file_path (str): Relative path to the ossec.conf or .xml file.

        Returns:
            ET.Element: The root element of the XML tree if successful, or None if an error occurs.
        """
        

        # Assert the file has either .conf or .xml extension
        if not file_path.endswith(('.conf', '.xml')):
            self.logger.error("Invalid file type: Provide a .conf or .xml file")
            return None

        try:
            # If it's a .conf file, copy and rename it to .xml
            if file_path.endswith('.conf'):
                new_file = os.path.splitext(file_path)[0] + '.xml'
                shutil.copy(file_path, new_file)
                file_path = new_file

            # Parse the XML file
            tree = ET.parse(file_path)
            root = tree.getroot()

            return root

        except ET.ParseError as e:
            self.logger.error(f"XML ParseError: {e}")
        except FileNotFoundError as e:
            self.logger.error(f"FileNotFoundError: {e}")
        except Exception as e:
            self.logger.error(f"Unexpected error: {e}")

        return None

    


    def _format_save_xml(self, file_path: str, wazuh_format=False):
        """
        Private method.
        Re-indent the XML file in place with xmlformatter (one tab per level),
        then drop the XML declaration from the start of the file if there is one.

        Args:
            file_path (str): Path to the xml file, rewritten in place.
            wazuh_format (bool): Accepted for call-site compatibility; the body
                does not read it.
        """
        # Pretty-print through xmlformatter and write the result back over the same file.
        pretty_printer = xmlformatter.Formatter(indent="1", indent_char="\t", preserve=["literal"])
        formatted_output = pretty_printer.format_file(file_path)
        xmlformatter.save_formatter_result(formatted_output, pretty_printer, True, input_file=file_path, outfile=file_path)

        with open(file_path, 'r') as handle:
            text = handle.read()

        # Remove the declaration when the file starts with one.
        header = re.match(r'<\?xml.*\?>', text)
        if header is not None:
            text = text.replace(header.group(0), '').lstrip()

        with open(file_path, 'w') as handle:
            handle.write(text)

    

    def save_xml_file(self, ossec_conf: ET.Element, file_path: str):
        """
        Save the ossec_conf XML to a .conf or .xml file, handling special cases for .conf files.

        Args:
            ossec_conf (ET.Element): The root element of the ossec.conf XML tree.
            file_path (str): Path to save the ossec.conf or .xml file.
        """
        # Ensure the file has either .conf or .xml extension
        if not file_path.endswith(('.conf', '.xml')):
            self.logger.error("Invalid file type: Provide a .conf or .xml file")
            return

        # Convert the XML tree to an ElementTree object
        tree = ET.ElementTree(ossec_conf)

        file_path_xml = os.path.splitext(file_path)[0] + '.xml'

        try:
            # Directly save to an .xml file with XML declaration
            tree.write(file_path_xml, encoding='utf-8', xml_declaration=True)

            # Apply formatting (if needed) and save it
            self._format_save_xml(file_path_xml)

            if file_path.endswith('.conf'):
                # Copy the .xml file to the .conf file
                shutil.copy(file_path_xml, file_path)

                # Copy the .xml file to the .conf file (after removing the XML declaration)
                self.logger.debug(f"Successfully saved the .conf file at {file_path}")

            else:
                self.logger.debug(f"Successfully saved the .xml file at {file_path}")

        except Exception as e:
            self.logger.error(f"Failed to save the file: {e}")

    # HERE Adders
    
    def add_group_to_xml(self, group_sections_str: str, file_path: str = None):
        """
        Add multiple group sections to the XML file if they don't exist. If they do, add only new rules.

        Args:
            group_sections_str (str): The raw XML string containing multiple group sections to add.
            file_path (str): Path to the XML file.
        Info:
            Reserved to manager
        """

        if self.agent_type != 'manager':
            self.logger.error(f"Agent is not a manager (found: {self.agent_type}), cannot proceed")
            return


        # Default xml file
        file_path = file_path or self.local_rules_xml_path

        root = self.load_xml_file(file_path)
        group_sections = ET.fromstring(f"<root>{group_sections_str}</root>")

        for group_section in group_sections.findall('.//group'):
            group_name = group_section.get('name')

            # Assert that the group section has the minimum required tags
            if group_name is None:
                self.logger.error(f"Group section missing 'name' attribute. Skipping.")
                continue

            rules = group_section.findall('.//rule')
            if not rules:
                self.logger.error(f"Group section '{group_name}' has no rules. Skipping.")
                continue

            for rule in rules:
                rule_id = rule.get('id')
                description = rule.find('description')
                if rule_id is None:
                    self.logger.error(f"Rule in group '{group_name}' missing 'id' attribute. Skipping.")
                    continue
                if description is None or description.text is None:
                    self.logger.error(f"Rule with id {rule_id} in group '{group_name}' missing 'description' tag. Skipping.")
                    continue

            group_exists = False

            for group in root.findall('.//group'):
                if group.get('name') == group_name:
                    group_exists = True
                    for rule in group_section.findall('.//rule'):
                        rule_id = rule.get('id')
                        if not any(r.get('id') == rule_id for r in group.findall('.//rule')):
                            group.append(rule)
                        else:
                            self.logger.warning(f"Rule with id {rule_id} already exists in group {group_name}. Skipping.")
                    break

            if not group_exists:
                root.append(group_section)

        self.save_xml_file(root, file_path)

    def add_decoder_to_xml(self, decoder_sections_str: str, file_path: str = None):
        """
        Add multiple decoder sections to the XML file if they don't exist.

        Args:
            decoder_sections_str (str): The raw XML string containing multiple decoder sections to add.
            file_path (str): Path to the XML file.
        Info:
            Reserved to manager
        """

        if self.agent_type != 'manager':
            self.logger.error(f"Agent is not a manager (found: {self.agent_type}), cannot proceed")
            return

        # Default xml file
        file_path = file_path or self.local_decoder_xml_path

        root = self.load_xml_file(file_path)
        decoder_sections = ET.fromstring(f"<root>{decoder_sections_str}</root>")

        for decoder_section in decoder_sections.findall('.//decoder'):
            decoder_name = decoder_section.get('name')
            decoder_exists = False

            for decoder in root.findall('.//decoder'):
                if decoder.get('name') == decoder_name:
                    decoder_exists = True
                    self.logger.warning(f"Decoder with name {decoder_name} already exists. Skipping.")
                    break

            if not decoder_exists:
                root.append(decoder_section)

        self.save_xml_file(root, file_path)


    # HERE Tags processing 

    def process_element(self, elem):
        """
        Recursively process an element to extract its text, attributes, and nested elements.

        Args:
            elem (ET.Element): The element to process.

        Returns:
            dict: Dictionary containing the extracted data.
        """
        result = {'text': elem.text.strip() if elem.text else None}
        if elem.attrib:
            result.update(elem.attrib)
        for child in elem:
            if child.tag in result:
                if isinstance(result[child.tag], list):
                    result[child.tag].append(self.process_element(child))
                else:
                    result[child.tag] = [result[child.tag], self.process_element(child)]
            else:
                result[child.tag] = self.process_element(child)
        return result

    def print_node_text(self, root: ET.Element):
        """
        Print the text content of each node in the XML tree.

        Args:
            root (ET.Element): The root element of the XML tree.
        """
        for child in root:
            print(f"Tag: {child.tag}, Attributes: {child.attrib}")
            for subchild in child:
                print(f"  Subtag: {subchild.tag}, Text: {subchild.text}")


    # HERE Compress xml functions

    def compress_tree(self, tree, TAG_TO_COMPRESS):
        root = tree.getroot()

        # Find all tags of the specified type (TAG_TO_COMPRESS)
        elements_to_compress = root.findall(f'.//{TAG_TO_COMPRESS}')

        # Filter the elements to keep only those without attributes
        elements_to_compress = [elem for elem in elements_to_compress if len(elem.attrib) == 0]

        # If there are elements to compress
        if elements_to_compress:
            # Collect the text content of all matching elements, join with commas
            compressed_text = ",".join(elem.text.strip() for elem in elements_to_compress if elem.text)

            # Create a mapping of elements to their parent
            parent_map = {c: p for p in root.iter() for c in p}
            parent = parent_map[elements_to_compress[0]]

            # Get the index of the first element to compress
            first_elem_index = list(parent).index(elements_to_compress[0])

            # Safely remove the elements to be compressed
            for elem in elements_to_compress:
                parent_map[elem].remove(elem)

            # Create a new compressed element and insert it where the original tags were
            compressed_element = ET.Element(TAG_TO_COMPRESS)
            compressed_element.text = compressed_text
            parent.insert(first_elem_index, compressed_element)  # Insert at the same location

            self.logger.debug(f"Successfully compressed {TAG_TO_COMPRESS} tags.")
        else:
            self.logger.debug(f"No {TAG_TO_COMPRESS} tags found to compress.")

    def compress_xml(self, input_file, output_file, TAG_TO_COMPRESS):
        """
        Read an XML file, merge its ``TAG_TO_COMPRESS`` elements with compress_tree,
        write the result to ``output_file`` without an XML declaration and format
        that file with _format_save_xml.

        Args:
            input_file (str): Path to the input XML file to be compressed.
            output_file (str): Path to the output XML file where the compressed XML will be saved.
            TAG_TO_COMPRESS (str): The name of the XML tag to compress (e.g., 'ignore').
        """
        self.logger.debug(f"Running compress_xml with input file: {input_file}, output file: {output_file}, tag to compress: {TAG_TO_COMPRESS}")

        document = ET.parse(input_file)
        self.compress_tree(document, TAG_TO_COMPRESS)

        # Serialize first, then re-indent the written file in place.
        document.write(output_file, encoding='utf-8', xml_declaration=False)
        self._format_save_xml(output_file)

        self.logger.debug(f"Compressed XML saved to {output_file}")

    def compress_elements(self, parent: ET.Element, TAG_TO_COMPRESS: list):
        """
        Compress elements within a parent based on the tags to compress.

        Args:
            parent (ET.Element): The parent element.
            TAG_TO_COMPRESS (list): List of tags to compress.
        """
        for tag in TAG_TO_COMPRESS:
            elements_to_compress = parent.findall(f'.//{tag}')

            # Group elements by their attributes
            grouped_elements = defaultdict(list)
            for elem in elements_to_compress:
                attributes = frozenset(elem.attrib.items())
                grouped_elements[attributes].append(elem)

            # Dictionary to keep track of compressed elements
            compressed_elements = {}

            for attributes, elements in grouped_elements.items():
                all_texts = []
                for elem in elements:
                    if elem.text:
                        all_texts.extend(elem.text.strip().split(','))

                # Remove duplicates and log if found
                unique_texts = sorted(set(all_texts))
                if len(unique_texts) < len(all_texts):
                    self.logger.warning(f"Redundant entries found in tag '{tag}' with attributes {attributes}. Removed duplicates.")

                compressed_text = ",".join(unique_texts)
                if compressed_text:
                    # Check for redundancy
                    if compressed_text in compressed_elements:
                        continue

                    first_elem_index = list(parent).index(elements[0])

                    for elem in elements:
                        parent.remove(elem)

                    compressed_element = ET.Element(tag, attrib=dict(attributes))
                    compressed_element.text = compressed_text
                    parent.insert(first_elem_index, compressed_element)

                    # Add the compressed text to the dictionary to track redundancy
                    compressed_elements[compressed_text] = compressed_element
    
    # HERE Remote function #
    def check_in_tag(self,tag: str,text: str,attributes: dict = None, xml_path  = None) -> bool:
        """
        Run the imported ``check_in_tag`` helper on the remote host.

        The helper is handed to ``run_function_on_remote_host`` together with the
        argument list ``[tag, text, xml_path, attributes]``; whatever that call
        returns is returned unchanged.

        Args:
            tag (str): Tag forwarded to the helper.
            text (str): Text forwarded to the helper.
            attributes (dict, optional): Attributes forwarded to the helper.
            xml_path (optional): Path forwarded to the helper; ``self.VM_conf_path``
                is used when no (or an empty) path is given.

        Returns:
            bool: Result of the remote call (presumably whether the text was found
            in the tag; confirm against utils.check_in_tag).
        """
        target_xml = xml_path if xml_path else self.VM_conf_path
        remote_args = [tag, text, target_xml, attributes]
        return self.run_function_on_remote_host(check_in_tag, remote_args)
    
    # HERE Tests


    def generate_testing_logs(self, group_sections_str: str) -> dict:
        """
        Generate testing logs for the added rules and decoders in the group sections.

        Args:
            group_sections_str (str): The raw XML string containing multiple group sections.

        Returns:
            dict: A dictionary where the key is the rule ID and the value is the associated testing log.
        """
        # Parse the decoder file
        decoder_tree = ET.parse(self.local_decoder_xml_path)
        decoder_root = decoder_tree.getroot()

        # Create a dictionary to map decoder names to their pre-match patterns
        decoder_map = {}
        for decoder in decoder_root.findall('.//decoder'):
            decoder_name = decoder.get('name')
            prematch = decoder.find('prematch').text if decoder.find('prematch') is not None else ''
            decoder_map[decoder_name] = prematch

        # Parse the group sections
        group_sections = ET.fromstring(f"<root>{group_sections_str}</root>")
        testing_logs = {}

        for group_section in group_sections.findall('.//group'):
            group_name = group_section.get('name')

            for rule in group_section.findall('.//rule'):
                rule_id = rule.get('id')
                rule_description = rule.find('description').text if rule.find('description') is not None else 'No description'
                rule_match = rule.find('match').text if rule.find('match') is not None else ''
                decoded_as = rule.find('decoded_as').text if rule.find('decoded_as') is not None else ''

                # Generate a testing log based on the rule options and decoder
                if rule_match:
                    testing_log = f"{rule_match} {rule_description}"
                    testing_logs[rule_id] = testing_log
                elif decoded_as and decoded_as in decoder_map:
                    prematch = decoder_map[decoded_as]
                    testing_log = f"{prematch} {rule_description}"
                    testing_logs[rule_id] = testing_log
                else:
                    self.logger.warning(f"No match pattern or decoder found for rule ID {rule_id} in group {group_name}. Skipping.")

        return testing_logs


    
    def test_group_sections(self, group_sections_str: str, tmp_path:str):
        """
        Test each group section in the given XML string using wazuh-logtest.

        Args:
            group_sections_str (str): The raw XML string containing multiple group sections to test.
            tmp_path(str) : tmp file to test before saving.
        Returns:
            int: Exit status (0 for success, 1 for failure).
        """
        self.logger.debug("Starting test for group sections.")

        # Generate testing logs
        testing_logs = self.generate_testing_logs(group_sections_str)
        self.logger.debug(f"Generated testing logs: {testing_logs}")

        # Create a temporary file in the dedicated folder
        temp_dir = 'tmp'
        os.makedirs(temp_dir, exist_ok=True)
        input_file_path = os.path.join(temp_dir, 'input_logs.txt')

        try:
            # Write all log tests to the input file
            with open(input_file_path, 'w') as input_file:
                for rule_id, log_test in testing_logs.items():
                    input_file.write(log_test + '\n')
            self.logger.debug(f"Log tests written to {input_file_path}")

            # Get the absolute path and adjust it to reference the /vagrant folder
            absolute_input_path = os.path.abspath(input_file_path)
            self.logger.debug(f"Absolute input path: {absolute_input_path}")

            # Check if '/vagrant' is part of the absolute path
            if '/vagrant' in absolute_input_path:
                # Cut the absolute path to keep the part starting from '/vagrant'
                vagrant_input_path = absolute_input_path.split('/vagrant', 1)[-1]
                vagrant_input_path = '/vagrant' + vagrant_input_path
            else:
                vagrant_input_path = absolute_input_path
            self.logger.debug(f"Vagrant input path: {vagrant_input_path}")

            # Path to the new XML file to test before sync
            path_tmp_xml_file = os.path.join('/vagrant/src', tmp_path)
            self.logger.debug(f"Path to the new XML file: {path_tmp_xml_file}")

            # Run the wazuh-logtest command remotely
            command = [
                '/var/ossec/bin/wazuh-logtest',
                '-l', path_tmp_xml_file,
                '<', vagrant_input_path
            ]
            self.logger.debug(f"Running command: {command}")
            result = self.run_remote_command(command)

            # Analyze the output
            output_content = result.stderr # Considered as an error , anyway i don't know why
            self.logger.debug(f"wazuh-logtest output: {output_content}")

            all_rules_matched = True

            #BUG TO ENSURE IT IS CORRECTED AS RULE ID ISN'T recognized...
            for rule_id, log_test in testing_logs.items():
                if f'id: \'{rule_id}\'' in output_content:
                    self.logger.debug(f"Rule ID {rule_id} reached Phase 3 and matched successfully.")
                else:
                    self.logger.error(f"Rule ID {rule_id} did not match the expected ID.")
                    all_rules_matched = False

            if all_rules_matched:
                self.logger.info("All rules matched successfully.")
                return True
            else:
                self.logger.error("Some rules did not match the expected IDs.")
                return False

        except Exception as e:
            self.logger.error(f"An error occurred during testing: {e}")
            return 1

        finally:
            # Clean up the temporary file
            os.remove(input_file_path)
            self.logger.debug(f"Cleaned up temporary file: {input_file_path}")

    

    def _wazuh_format_xml(self, file_path):
        """
        Produce a version of the XML file without its enclosing <root> tag.

        The file is first re-indented in place with xmlformatter. When its
        top-level tag is <root>, the children of that tag are serialized one after
        the other into tmp/formatted_xml.xml (relative to the current directory)
        and that path is returned. Otherwise the file is left as formatted and its
        own path is returned.

        Args:
            file_path (str): Path to the XML file.

        Returns:
            str: Path to the file holding the content without <root>, or
            ``file_path`` when the top-level tag is not <root>.

        Raises:
            Exception: Any error met on the way is logged and re-raised.
        """
        try:
            self.logger.debug(f"Formatting XML file at {file_path} for Wazuh handling.")
            pretty_printer = xmlformatter.Formatter(indent="1", indent_char="\t", preserve=["literal"])
            formatted_output = pretty_printer.format_file(file_path)

            # Write the formatted result back over the same file.
            xmlformatter.save_formatter_result(formatted_output, pretty_printer, True, input_file=file_path, outfile=file_path)

            with open(file_path, 'r') as source:
                text = source.read()
            self.logger.debug(f"Read content from {file_path}.")

            document_root = ET.fromstring(text)
            self.logger.debug("Parsed XML content successfully.")

            if document_root.tag != 'root':
                self.logger.warning("Root tag is not <root>. No formatting needed but you ask for formatting in wazuh")
                return file_path

            self.logger.debug("Root tag is <root>. Extracting inner content.")
            serialized_children = [ET.tostring(node, encoding='unicode') for node in document_root]
            inner_content = ''.join(serialized_children)

            # The stripped content goes to a fixed file in the local 'tmp' directory.
            temp_dir = 'tmp'
            os.makedirs(temp_dir, exist_ok=True)
            tmp_file_path = os.path.join(temp_dir, 'formatted_xml.xml')
            with open(tmp_file_path, 'w') as target:
                target.write(inner_content)
            self.logger.debug(f"Inner content written to temporary file at {tmp_file_path}.")

            return tmp_file_path

        except Exception as e:
            self.logger.error(f"An error occurred while formatting the XML file: {e}")
            raise

                
    def synchronize_xml_with_VM(self, group_sections_str=None, decoder_sections_str=None, xml_path=None, vm_xml_path=None):
        """
        Synchronize XML configurations with the VM.

        Args:
            group_sections_str (str, optional): The raw XML string containing multiple group sections.
            decoder_sections_str (str, optional): The raw XML string containing multiple decoder sections.
            xml_path (str, optional): Path to the local XML file.
            vm_xml_path (str, optional): Path to the VM XML file.
        """
        try:


            # HERE wazuh formatting sections
            
            if decoder_sections_str:
                # Add the decoder section to the XML file
                self.add_decoder_to_xml(decoder_sections_str)
                self.logger.debug("Decoder sections added to the XML file.")

                self._format_save_xml(self.local_decoder_xml_path,wazuh_format=True)
                
                # Wazuh format 
                tmp_path = self._wazuh_format_xml(self.local_decoder_xml_path)
                # Synchronize the local decoder XML with the VM
                self.synchronize_with_VM(tmp_path, self.VM_local_decoder_path)
                self.logger.debug("Local decoder XML synchronized with the VM.")

            if group_sections_str:
                # Add the group section to the XML file
                self.add_group_to_xml(group_sections_str)
                self.logger.debug("Group sections added to the XML file.")


                #Wazuh format

                tmp_path = self._wazuh_format_xml(self.local_rules_xml_path)

                # Test the group sections
                if self.test_group_sections(group_sections_str,tmp_path):

                    self._format_save_xml(self.local_rules_xml_path,wazuh_format=True)
                    self.logger.info("Group sections tested successfully.")


                    # Synchronize the local rules XML with the VM
                    self.synchronize_with_VM(tmp_path, self.VM_local_rules_path)
                    self.logger.info("Local rules XML synchronized with the VM.")
                    return True
                else: 
                    self.logger.error("Local rules XML file not updated")
                    return False


            else:
                # Synchronize the specified XML paths with the VM
                self.synchronize_with_VM(xml_path, vm_xml_path)
                self.logger.debug(f"XML file at {xml_path} synchronized with the VM at {vm_xml_path}.")

        except Exception as e:
            self.logger.error(f"An error occurred during XML synchronization: {e}")
            raise

