import logging
import json
import csv
import time
from pathlib import Path
import pandas as pd
from google.genai.errors import ClientError, ServerError
from google import genai
import argparse
import os
import sys
import uuid
import requests
import re

# =========================
# CONFIGURABLE VARIABLES
# =========================
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from config import (
    REPORT_FOLDER,
    DISK_SUBFOLDER,
    MEMORY_SUBFOLDER,
    NETWORK_SUBFOLDER,
    SUPPORTED_EXTENSIONS,
    MAX_TPM,
    MAX_RETRIES,
    RETRY_WAIT,
    GEMINI_API_KEY,
    GEMINI_API_KEY_2,
    GEMINI_API_KEY_3,
    GEMINI_API_KEY_4,
    FINAL_GEMINI_MODEL_NAME,
    OLLAMA_BASE_URL,
    OLLAMA_MODEL,
    MAIN_FORENSIC_INSTRUCTION,
    CHUNK_PROMPT_TEMPLATE,
    FINAL_PROMPT_TEMPLATE
)

# =========================
# SETUP LOGGING
# =========================
# Emit timestamped, level-tagged log records at INFO and above.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Gemini API keys available to this script; the client starts on the first one
# and current_key_index records which entry is currently in use.
GEMINI_KEYS = [
    GEMINI_API_KEY,
    GEMINI_API_KEY_2,
    GEMINI_API_KEY_3,
    GEMINI_API_KEY_4,
]
current_key_index = 0
client = genai.Client(api_key=GEMINI_KEYS[current_key_index])

# =========================
# FUNCTION DEFINITIONS
# =========================

def chunk_by_words(input_data, max_chunk_size):
    """Split text into chunks made of whole, space-joined words.

    Each word is charged ``len(word) + 1`` characters (the word plus one
    separator). A new chunk is started when adding the next word would push
    the running charge above ``max_chunk_size``. Words are never split, so a
    single word longer than ``max_chunk_size`` becomes a chunk of its own.

    Args:
        input_data: Text to split; runs of whitespace act as one separator.
        max_chunk_size: Character budget per chunk.

    Returns:
        A list of non-empty chunk strings; ``[]`` when the input has no words.
    """
    chunks = []
    current_chunk = []
    current_length = 0

    for word in input_data.split():
        cost = len(word) + 1
        # Only flush when there is something to flush: an over-long word that
        # arrives while the chunk is empty must not produce an empty chunk.
        if current_chunk and current_length + cost > max_chunk_size:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(word)
        current_length += cost

    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks

def generate_toc(markdown: str) -> str:
    """Build a nested Markdown bullet list linking to every heading in the text.

    A heading is any line that starts with one to six ``#`` characters followed
    by a space. Each entry is indented two spaces per heading level below 1.

    Anchors are derived from the title by dropping every character that is not
    a word character, hyphen or space, lower-casing, and turning spaces into
    hyphens. When the same anchor would be produced more than once, later
    occurrences get ``-1``, ``-2``, ... appended so that each TOC entry links
    to its own heading rather than to the first heading with that title.

    Args:
        markdown: The Markdown document to scan.

    Returns:
        The table of contents, one entry per line; ``""`` if there are no
        headings.
    """
    entries = []
    # Anchor -> number of duplicates handed out so far for that base slug.
    seen = {}
    for line in markdown.splitlines():
        match = re.match(r'^(#{1,6}) (.*)', line)
        if not match:
            continue
        level = len(match.group(1))   # number of '#' = heading level
        title = match.group(2).strip()

        slug = re.sub(r'[^\w\- ]', '', title).lower().replace(" ", "-")

        # Disambiguate repeated headings. The loop (rather than a plain
        # counter) also avoids colliding with a heading whose own slug already
        # looks like "<slug>-1".
        anchor = slug
        while anchor in seen:
            seen[slug] += 1
            anchor = f"{slug}-{seen[slug]}"
        seen[anchor] = 0

        indent = "  " * (level - 1)
        entries.append(f"{indent}- [{title}](#{anchor})")
    return "\n".join(entries)

def collect_files(folder_path: Path, extensions=SUPPORTED_EXTENSIONS):
    """Recursively list the regular files under a folder that have a supported suffix.

    Args:
        folder_path: Directory to search, including all sub-directories.
        extensions: Collection of lower-case suffixes (with the leading dot)
            to accept; the file suffix is lower-cased before the comparison.

    Returns:
        The matching paths in sorted order, or ``[]`` (after logging a
        warning) when ``folder_path`` does not exist.
    """
    if not folder_path.exists():
        logging.warning(f"Folder not found: {folder_path}")
        return []

    # rglob("*") also yields directories, so a directory named e.g. "dump.md"
    # must be filtered out explicitly. Sorting makes the result independent of
    # the order in which the filesystem happens to list entries.
    files = sorted(
        f for f in folder_path.rglob("*")
        if f.is_file() and f.suffix.lower() in extensions
    )
    logging.info(f"Found {len(files)} files in {folder_path}")
    return files

def load_and_format_file(file_path: Path):
    """Read one analysis file and return its content as Markdown-ready text.

    ``.md``, ``.json`` and ``.txt`` files are returned verbatim (decoded as
    UTF-8). ``.csv`` files are parsed with pandas, re-serialised without the
    index, and wrapped in a fenced ``csv`` code block. The suffix is compared
    case-insensitively, so ``REPORT.MD`` is handled like ``report.md``.

    This is a best-effort loader: any failure is logged and reported as an
    empty string rather than raised.

    Args:
        file_path: Path of the file to load.

    Returns:
        The formatted content, or ``""`` for an unsupported suffix or when the
        file cannot be read or parsed.
    """
    suffix = file_path.suffix.lower()
    try:
        if suffix in (".md", ".json", ".txt"):
            with open(file_path, "r", encoding="utf-8") as f:
                return f.read()

        if suffix == ".csv":
            df = pd.read_csv(file_path)
            csv_text = df.to_csv(index=False)
            return f"```csv\n{csv_text}\n```"

        logging.warning(f"Unsupported file type: {file_path}")
        return ""
    except Exception as e:
        # Deliberately broad: one unreadable file must not abort the report.
        logging.error(f"Error loading {file_path}: {e}")
        return ""

def build_report(case_folder: Path):
    """Assemble the Markdown body of the forensic report for one case folder.

    The report opens with a title and a "Table of Contents" heading, followed
    by one section per analysis type (disk, memory, network). Within a
    section, files are grouped by extension in SUPPORTED_EXTENSIONS order and
    each file's content is emitted as a block quote under its file name.
    JSON files are left out of the network section, and the Markdown group of
    the memory section carries the heading "MD Reports Memory".

    Args:
        case_folder: Root folder of the case.

    Returns:
        The report as a single newline-joined string.
    """
    sections = {
        "Disk Analysis": case_folder / DISK_SUBFOLDER,
        "Memory Analysis": case_folder / MEMORY_SUBFOLDER,
        "Network Analysis": case_folder / NETWORK_SUBFOLDER
    }
    out = ["# Forensic Report", "\n## Table of Contents\n"]

    for section_title, section_dir in sections.items():
        out.append(f"## {section_title}")
        found = collect_files(section_dir)
        if not found:
            out.append("_No analysis found_\n")
            continue

        is_network = section_title == "Network Analysis"
        is_memory = section_title == "Memory Analysis"

        for ext in SUPPORTED_EXTENSIONS:
            # JSON artefacts are not included for the network section.
            if is_network and ext == ".json":
                continue

            matching = [p for p in found if p.suffix.lower() == ext]
            if not matching:
                continue

            # The memory section's Markdown group gets its own heading text.
            if is_memory and ext == ".md":
                heading = "### MD Reports Memory"
            else:
                heading = "### " + ext.replace(".", "").upper() + " Reports"
            out.append(heading)

            for p in matching:
                logging.info(f"Processing {p}")
                raw = load_and_format_file(p)
                quoted = "\n".join("> " + ln for ln in raw.splitlines())
                out.append(f"**{p.name}**\n\n{quoted}\n")

    return "\n".join(out)

# Returned in place of a summary whenever the LLM call cannot be completed.
_SUMMARY_FAILED = "_Gemini analysis failed._"
# Character budget for one Ollama chunk (compared against len() of the text).
_OLLAMA_CHUNK_CHARS = 5000
# How many of the most recent chunk summaries are fed back as context.
_ROLLING_SUMMARY_COUNT = 3
# Seconds to wait for a single Ollama HTTP response.
_OLLAMA_TIMEOUT = 300
# Character cap applied to a Gemini prompt that is judged too large.
_GEMINI_PROMPT_CHAR_CAP = 125000
# Substring of the error text that marks an exhausted per-day quota.
_DAILY_QUOTA_MARKER = "GenerateRequestsPerDayPerProjectPerModel"
# Substrings of the error text that mark a Gemini failure as worth retrying.
_TRANSIENT_ERROR_MARKERS = ("RESOURCE_EXHAUSTED", "429", "500", "503")


def _ollama_generate(prompt, label):
    """Send one prompt to Ollama's ``/api/generate`` endpoint and return its text.

    The request is retried on ``requests`` errors, waiting ``RETRY_WAIT``
    seconds between attempts. ``label`` names the request in log messages.

    Raises:
        RuntimeError: when every attempt failed; chained to the last error.
    """
    attempts = max(1, MAX_RETRIES)
    body = json.dumps({"model": OLLAMA_MODEL, "prompt": prompt, "stream": False})
    logging.debug(f"Request ({label}):\n\n{prompt}\n\n")

    for attempt in range(1, attempts + 1):
        try:
            response = requests.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                headers={"Content-Type": "application/json"},
                data=body,
                timeout=_OLLAMA_TIMEOUT,
            )
            response.raise_for_status()
            text = response.json().get("response", "")
            logging.debug(f"Response ({label}):\n\n{text}\n\n")
            return text
        except requests.exceptions.RequestException as e:
            logging.error(f"{label} attempt {attempt} failed: {e}")
            if attempt == attempts:
                raise RuntimeError(f"All retries failed for {label}") from e
            time.sleep(RETRY_WAIT)


def _summarize_with_ollama(report_content):
    """Summarise the report with the local Ollama model, chunking long input."""
    instruction = MAIN_FORENSIC_INSTRUCTION

    # Short input: a single request carrying instruction + report.
    if len(report_content) <= _OLLAMA_CHUNK_CHARS:
        return _ollama_generate(instruction + report_content, "final synthesis").strip()

    logging.info(
        f"Input length {len(report_content)} exceeds {_OLLAMA_CHUNK_CHARS}, chunking..."
    )
    # Drop empty chunks so that no empty prompt is ever sent.
    chunks = [c for c in chunk_by_words(report_content, _OLLAMA_CHUNK_CHARS) if c]

    summaries = []
    for number, chunk_text in enumerate(chunks, start=1):
        # Give the model the latest few summaries as rolling context.
        context = ""
        if summaries:
            recent = "\n".join(summaries[-_ROLLING_SUMMARY_COUNT:])
            context = f"Previous summaries:\n{recent}\n"

        if CHUNK_PROMPT_TEMPLATE:
            chunk_prompt = CHUNK_PROMPT_TEMPLATE.format(
                chunk_number=number,
                total_chunks=len(chunks),
                chunk_text=chunk_text,
                context=context,
            )
        else:
            # No template configured: send context and chunk as they are.
            chunk_prompt = context + chunk_text

        summary_text = _ollama_generate(chunk_prompt, f"chunk {number}")
        summaries.append(f"[Chunk {number}] {summary_text}")

    joined = "\n".join(summaries)

    # With no instruction and at most one summary there is nothing to
    # synthesise; hand back what we have (as text, not as a dict).
    if not instruction and len(summaries) <= 1:
        return joined

    directive = instruction or "- Synthesize these into a coherent summary."
    if FINAL_PROMPT_TEMPLATE:
        final_prompt = FINAL_PROMPT_TEMPLATE.format(
            number_of_summaries=len(summaries),
            summaries=joined,
            main_instruction=directive,
        )
    else:
        final_prompt = f"{directive}\n\n{joined}"

    return _ollama_generate(final_prompt, "final synthesis").strip()


def _rotate_gemini_key():
    """Rebuild the module-level Gemini client with the next key in GEMINI_KEYS.

    Returns:
        True if another key was available and is now in use, False if the
        key list is exhausted.
    """
    global client, current_key_index
    if current_key_index >= len(GEMINI_KEYS) - 1:
        return False
    current_key_index += 1
    logging.warning(f"Quota exceeded. Switching to API key index {current_key_index}.")
    client = genai.Client(api_key=GEMINI_KEYS[current_key_index])
    return True


def generate_summary(report_content: str, local=False):
    """Produce the forensic summary for a report using Gemini or a local Ollama model.

    Local mode (``local=True``) posts to the Ollama HTTP API. Input longer
    than 5000 characters is split into word-aligned chunks, each chunk is
    summarised with the previous three summaries as context, and the chunk
    summaries are then combined in a final synthesis request.

    Gemini mode sends ``MAIN_FORENSIC_INSTRUCTION + report_content`` in a
    single request, truncating the prompt to 125000 characters when it is
    judged too large. The call is attempted up to ``MAX_RETRIES`` times:
    a daily-quota error switches to the next API key, a transient
    quota/server error waits ``RETRY_WAIT`` seconds once before the next
    attempt, and any other error stops immediately.

    Args:
        report_content: The Markdown report to summarise.
        local: Use the local Ollama model instead of the Gemini API.

    Returns:
        The summary text. On any failure the placeholder string
        ``"_Gemini analysis failed._"`` is returned instead, so the result is
        always a string.
    """
    if local:
        try:
            return _summarize_with_ollama(report_content)
        except Exception as e:
            # Boundary handler: the caller always gets a string back.
            logging.error(f"Local Ollama analysis failed: {e}")
            return _SUMMARY_FAILED

    prompt = MAIN_FORENSIC_INSTRUCTION + report_content
    # Size estimate carried over unchanged: character count scaled by 1.3 and
    # compared with MAX_TPM.
    if len(prompt) * 1.3 >= MAX_TPM:
        logging.warning(f"Prompt exceeds {MAX_TPM} token limit, truncating...")
        prompt = prompt[:_GEMINI_PROMPT_CHAR_CAP]

    attempts = max(1, MAX_RETRIES)
    for attempt in range(1, attempts + 1):
        try:
            response = client.models.generate_content(
                model=FINAL_GEMINI_MODEL_NAME,
                contents=[{"role": "user", "parts": [{"text": prompt}]}],
                config={"temperature": 0.3},
            )
            text = response.text
        except Exception as e:
            # Failures are classified by their message text.
            error_str = str(e)

            if _DAILY_QUOTA_MARKER in error_str:
                if not _rotate_gemini_key():
                    logging.error("All Gemini API keys exhausted.")
                    return _SUMMARY_FAILED
                continue  # retry right away with the new key

            if any(marker in error_str for marker in _TRANSIENT_ERROR_MARKERS):
                # Wait once per failed attempt; no point sleeping after the last.
                if attempt < attempts:
                    logging.warning(
                        f"Quota/server error on attempt {attempt}/{attempts}. "
                        f"Retrying in {RETRY_WAIT}s..."
                    )
                    time.sleep(RETRY_WAIT)
                continue

            logging.error(f"Gemini analysis failed: {error_str}")
            return _SUMMARY_FAILED

        logging.debug(f"Gemini raw response: {response}")
        if text:
            logging.info("Gemini analysis completed successfully.")
            return text.strip()

        logging.error("Empty response from Gemini LLM.")
        return _SUMMARY_FAILED

    logging.error("Gemini analysis failed after maximum retries.")
    return _SUMMARY_FAILED

# =========================
# MAIN SCRIPT
# =========================
def main():
    """Command-line entry point: build, summarise and save the forensic report.

    Parses the case folder (and the optional ``--local`` switch), builds the
    report body, appends the LLM summary under "Forensic Summary & Verdict",
    fills in the table of contents and writes the result to
    ``<case_folder>/<REPORT_FOLDER>/forensic_report.md``, creating the report
    folder if needed.
    """
    parser = argparse.ArgumentParser(description="Generate forensic report for a case folder.")
    parser.add_argument("case_folder", type=str, help="Path to the case folder")
    parser.add_argument("--local", action="store_true", help="Use local Ollama model instead of Gemini API")
    args = parser.parse_args()

    case_folder = Path(args.case_folder)
    report_folder = case_folder / REPORT_FOLDER
    report_folder.mkdir(parents=True, exist_ok=True)
    report_file = report_folder / "forensic_report.md"

    logging.info(f"Building forensic report for case folder: {case_folder}")
    report_content = build_report(case_folder)

    backend = "local Ollama model" if args.local else "Gemini API"
    logging.info(f"Generating forensic summary via {backend}...")
    summary = generate_summary(report_content, args.local)

    # Defensive: tolerate a non-string result so the report is still written
    # instead of failing on the concatenation below.
    if isinstance(summary, dict):
        summary = summary.get("response")
    if not isinstance(summary, str):
        summary = "_Gemini analysis failed._"

    report_content += "\n\n## Forensic Summary & Verdict\n" + summary

    logging.info("Generating Table of Contents...")
    toc = generate_toc(report_content)

    # Fill only the first occurrence, which is the placeholder at the top of
    # the report; the same text may also appear inside quoted input files or
    # in the generated summary and must be left alone.
    report_content = report_content.replace(
        "## Table of Contents", "## Table of Contents\n\n" + toc, 1
    )

    with open(report_file, "w", encoding="utf-8") as f:
        f.write(report_content)

    logging.info(f"Forensic report saved to: {report_file}")

if __name__ == "__main__":
    main()
