import os
import sys
import argparse
import logging
from pathlib import Path
from datetime import datetime
from google import genai
import time
from google.genai.errors import ClientError, ServerError
from collections import deque
import threading
import requests
import json

# ==========================
# GLOBAL CONFIGURATION
# ==========================
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from config import (
    GEMINI_API_KEY,
    GEMINI_API_KEY_2,
    GEMINI_API_KEY_3,
    GEMINI_API_KEY_4,
    OUTPUT_SUBDIR_MEMORY,
    MAX_TPM,
    MAX_RPM,
    DB_JSON_PATH,
    LOCK_PATH,
    RETRY_WAIT,
    MAX_RETRIES,
    ASK_GEMINI_LOG_LEVEL,
    OLLAMA_BASE_URL,
    OLLAMA_MODEL,
    FORENSIC_MEMORY_PROMPT,
    FORENSIC_MEMORY_PROMPT_WCLEAN,
    FORENSIC_CHUNK_MEMORY_PROMPT
)

# NOTE(review): the three constants below are not referenced anywhere in the
# visible part of this file — confirm whether they are still needed.
# Presumably the name of the environment variable that may hold a Gemini key.
GEMINI_API_KEY_ENV_VAR = "GEMINI_API_KEY"
# Placeholder value; it is not a usable API key.
DEFAULT_GEMINI_API_KEY = "ASD"
# strftime-style pattern: year, month, day, underscore, hour, minute, second.
TIMESTAMP_FORMAT = "%Y%m%d_%H%M%S"

# Sub-directory (taken from config) that main() appends to the case folder
# to build the report output directory.
OUTPUT_SUBDIR = OUTPUT_SUBDIR_MEMORY

# ==========================
# RATE LIMITER
# ==========================

import json
import time
import threading
from filelock import FileLock

class RateLimiter:
    """Sliding 60-second request/token limiter backed by a shared JSON file.

    Every granted request is stored as ``{"ts": <epoch seconds>, "tokens": <int>}``
    in the ``"requests"`` list of the JSON document at ``DB_JSON_PATH``.
    ``self.lock`` serializes threads of this process; a single ``FileLock`` on
    ``LOCK_PATH`` serializes access to the JSON file.

    Args:
        max_rpm (int): Maximum number of requests allowed per 60-second window.
        max_tpm (int): Maximum number of estimated tokens per 60-second window.
    """

    def __init__(self, max_rpm: int, max_tpm: int):
        self.max_rpm = max_rpm
        self.max_tpm = max_tpm
        self.lock = threading.Lock()
        # One FileLock object is reused by every method: filelock is re-entrant
        # only per lock object, so wait_for_slot() can hold it across the whole
        # read-modify-write cycle while _load_data()/_save_data() re-acquire it.
        self._file_lock = FileLock(LOCK_PATH)
        # Ensure the JSON file exists; done under the file lock so a concurrent
        # process cannot have its freshly written data overwritten by this init.
        with self._file_lock:
            try:
                with open(DB_JSON_PATH, "r"):
                    pass
            except FileNotFoundError:
                with open(DB_JSON_PATH, "w") as f:
                    json.dump({"requests": []}, f)

    def estimate_tokens(self, text: str) -> int:
        """Return a rough token estimate: whitespace-separated words times 1.3."""
        return int(len(text.split()) * 1.3)

    def _load_data(self):
        """Read and return the JSON document while holding the file lock."""
        with self._file_lock:
            with open(DB_JSON_PATH, "r") as f:
                return json.load(f)

    def _save_data(self, data):
        """Overwrite the JSON document with ``data`` while holding the file lock."""
        with self._file_lock:
            with open(DB_JSON_PATH, "w") as f:
                json.dump(data, f)

    def _cleanup_old(self, data, now):
        """Drop entries older than 60 seconds relative to ``now`` and return ``data``."""
        data["requests"] = [r for r in data["requests"] if now - r["ts"] <= 60]
        return data

    def wait_for_slot(self, prompt: str):
        """Block until ``prompt`` fits in the current window, then record it.

        Args:
            prompt (str): Text about to be sent; its size is estimated with
                ``estimate_tokens``.

        Raises:
            ValueError: If the prompt alone exceeds ``max_tpm`` and could
                therefore never be admitted.
        """
        tokens = self.estimate_tokens(prompt)
        if tokens > self.max_tpm:
            raise ValueError(
                f"Prompt too large: {tokens} tokens exceeds max per-minute limit ({self.max_tpm})"
            )
        while True:
            # Hold the file lock for the whole load -> prune -> append -> save
            # sequence; releasing it between load and save would let another
            # process grant itself the same slot and lose one of the entries.
            with self.lock, self._file_lock:
                now = time.time()
                data = self._cleanup_old(self._load_data(), now)
                window = data["requests"]

                req_count = len(window)
                token_count = sum(r["tokens"] for r in window)
                print(f"📊 Stato finestra 60s → richieste={req_count}, token={token_count}")

                if req_count < self.max_rpm and token_count + tokens <= self.max_tpm:
                    # Room available: record this request and hand control back.
                    window.append({"ts": now, "tokens": tokens})
                    self._save_data(data)
                    return

                if window:
                    # Wait until the oldest entry leaves the 60s window,
                    # but never less than 0.1s to avoid a busy loop.
                    oldest_ts = min(r["ts"] for r in window)
                    wait_time = max(60 - (now - oldest_ts), 0.1)
                else:
                    wait_time = 1

                print(f"⏳ Rate limit raggiunto → attendo {wait_time:.2f}s")
            # Sleep with both locks released so other callers can make progress.
            time.sleep(wait_time)

# Module-wide limiter instance; constructing it creates the JSON request log
# at DB_JSON_PATH if it does not exist yet.
rate_limiter = RateLimiter(max_rpm=MAX_RPM, max_tpm=MAX_TPM)
# ==========================
# INITIALIZATION
# ==========================
# Pool of Gemini API keys; generate_gemini_response() rotates this deque when
# the active key reports its daily quota as exhausted.
api_keys = deque([GEMINI_API_KEY, GEMINI_API_KEY_2, GEMINI_API_KEY_3, GEMINI_API_KEY_4])
# Key currently in use (always the head of the deque).
current_api_key = api_keys[0]

# Root logger configuration; main() raises the level to DEBUG when --debug is given.
logging.basicConfig(
    level=ASK_GEMINI_LOG_LEVEL,
    format="%(asctime)s [%(levelname)s] %(message)s"
)

# ==========================
# UTILITY FUNCTIONS
# ==========================

def read_file_content(file_path: Path) -> str:
    """Return the UTF-8 text of ``file_path`` with surrounding whitespace removed."""
    with file_path.open("r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return raw_text.strip()

def save_response_to_file(output_dir: Path, report_name: str, response: str) -> None:
    """Write ``response`` as UTF-8 to ``<output_dir>/<report_name>.md`` and log the path.

    The directory is expected to exist already; this function does not create it.
    """
    destination = output_dir.joinpath(f"{report_name}.md")
    with destination.open("w", encoding="utf-8") as handle:
        handle.write(response)
    logging.info(f"✅ Report saved to: {destination}")

def _mask_api_key(key) -> str:
    """Return a log-safe rendering of an API key (at most its last four characters)."""
    text = str(key) if key else ""
    return f"***{text[-4:]}" if len(text) > 8 else "***"


def _extract_response_text(response) -> str | None:
    """Return the stripped text carried by a Gemini response, or None if there is none.

    ``response.text`` is preferred; only when it is missing or empty are the
    candidates scanned, using the first text-bearing part of each candidate.
    """
    direct_text = getattr(response, "text", None)
    if direct_text:
        return direct_text.strip()

    for candidate in getattr(response, "candidates", None) or []:
        content = getattr(candidate, "content", None)
        parts = getattr(content, "parts", None) if content else None
        first_text = next(
            (part.text for part in parts or [] if getattr(part, "text", None)),
            None,
        )
        if first_text and first_text.strip():
            return first_text.strip()
    return None


def generate_gemini_response(prompt: str, model: str, temperature: float) -> str:
    """Send ``prompt`` to a Gemini model and return the stripped response text.

    Each attempt first reserves a slot in the shared rate limiter. When the
    API reports that the free-tier daily quota is exhausted, the reservation
    is removed again, the next key of ``api_keys`` becomes the active one and
    the request is retried. Other API errors and empty responses are retried
    after ``RETRY_WAIT`` seconds.

    Args:
        prompt (str): Full prompt text.
        model (str): Gemini model name.
        temperature (float): Sampling temperature forwarded to the API.

    Returns:
        str: The model's answer with surrounding whitespace removed.

    Raises:
        ValueError: Propagated from the rate limiter when the prompt alone
            exceeds the per-minute token budget.
        RuntimeError: If no usable answer was obtained within ``MAX_RETRIES``
            attempts; the last API error, if any, is attached as the cause.
    """
    global current_api_key

    client = genai.Client(api_key=current_api_key)

    # NOTE(review): this guard compares 1.3 x the character count against a
    # hard-coded 125000 while the warning cites MAX_TPM, and prompts between
    # roughly 96k and 125k characters trigger the warning without actually
    # being shortened — confirm the intended limit.
    if len(prompt) * 1.3 >= 125000:
        logging.warning(f"prompt exceeded limit of {MAX_TPM} tokens, truncating...")
        prompt = prompt[:125000]

    last_error = None
    for attempt in range(1, MAX_RETRIES + 1):
        # --- RATE LIMIT CONTROL ---
        # Reserve exactly one slot per request actually sent: retries after an
        # empty answer are accounted for, and no slot is reserved after the
        # final failed attempt.
        rate_limiter.wait_for_slot(prompt)

        try:
            logging.debug("=== FULL PROMPT SENT TO MODEL ===\n%s", prompt)
            response = client.models.generate_content(
                model=model,
                contents=[{"role": "user", "parts": [{"text": prompt}]}],
                config={"temperature": temperature}
            )
        except (ClientError, ServerError) as e:
            last_error = e
            # Rotate API key if daily quota exceeded
            if "GenerateRequestsPerDayPerProjectPerModel-FreeTier" in str(e):
                # Never write the full key to the log.
                logging.warning(
                    "Quota exceeded for API key %s, rotating key...",
                    _mask_api_key(current_api_key),
                )

                # --- REMOVE LAST REQUEST FROM DB ---
                # The rejected call should not count against the window.
                with rate_limiter.lock:
                    data = rate_limiter._load_data()
                    if data["requests"]:
                        removed = data["requests"].pop()  # delete last request
                        logging.debug(f"Removed last request: {removed}")
                    rate_limiter._save_data(data)

                api_keys.rotate(-1)
                current_api_key = api_keys[0]
                client = genai.Client(api_key=current_api_key)
                continue

            logging.warning(
                f"⚠️ Gemini API error ({type(e).__name__}): {e}. "
                f"Retrying in {RETRY_WAIT}s (attempt {attempt}/{MAX_RETRIES})..."
            )
        else:
            if response:
                response_text = _extract_response_text(response)
                if response_text:
                    logging.debug("=== MODEL RESPONSE ===\n%s", response_text)
                    return response_text
                logging.warning(f"⚠️ Empty response.text from Gemini, waiting {RETRY_WAIT} sec...")
            else:
                logging.warning(f"⚠️ Empty response payload from Gemini, waiting {RETRY_WAIT} sec...")

        # Back off only when another attempt will follow.
        if attempt < MAX_RETRIES:
            time.sleep(RETRY_WAIT)

    raise RuntimeError("Max retries exceeded for Gemini API request") from last_error

def generate_local_response(prompt: str) -> str:
    """Ask the local Ollama server for a completion and return it stripped.

    Posts a non-streaming generate request for ``OLLAMA_MODEL`` to
    ``<OLLAMA_BASE_URL>/api/generate``.

    Args:
        prompt (str): Prompt text to send.

    Returns:
        str: The ``"response"`` field of the JSON reply (empty string when the
        field is absent), with surrounding whitespace removed.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # Streaming is disabled so the reply arrives as one JSON document.
    request_body = json.dumps({
        "model": OLLAMA_MODEL,
        "prompt": prompt,
        "stream": False,
    })
    endpoint = f"{OLLAMA_BASE_URL}/api/generate"
    reply = requests.post(
        endpoint,
        headers={"Content-Type": "application/json"},
        data=request_body,
    )

    # Surface HTTP-level failures before touching the body.
    reply.raise_for_status()

    parsed = reply.json()
    print("[INFO] Ollama generation successful.")
    generated = parsed.get("response", "")
    print(f"[DEBUG] Ollama response: {generated}")
    return generated.strip()

def build_prompt(clean_content: str | None, target_content: str) -> str:
    """Fill the forensic memory prompt template for the given inputs.

    Args:
        clean_content (str | None): Output of the clean reference machine;
            a falsy value (None or empty string) means no reference exists.
        target_content (str): Output of the machine under analysis.

    Returns:
        str: ``FORENSIC_MEMORY_PROMPT_WCLEAN`` formatted with the reference
        section when a reference is available, otherwise
        ``FORENSIC_MEMORY_PROMPT`` formatted with ``"--not available--"``.
    """
    if clean_content:
        template = FORENSIC_MEMORY_PROMPT_WCLEAN
        reference_block = f"    • A clean reference machine: <clean_out>\n{clean_content}"
    else:
        template = FORENSIC_MEMORY_PROMPT
        reference_block = "--not available--"

    return template.format(clean_section=reference_block, target_content=target_content)

def chunk_by_rows(text: str, max_rows: int = 200) -> list[str]:
    """
    Break line-oriented text into consecutive pieces without splitting a row.

    Leading and trailing whitespace of the whole text is discarded first.

    Args:
        text (str): Full line-oriented output to split.
        max_rows (int): Upper bound on the number of rows in each piece.

    Returns:
        list[str]: Pieces in original order, each with rows joined by newlines;
        empty when the stripped text has no rows.
    """
    rows = text.strip().splitlines()
    starts = range(0, len(rows), max_rows)
    return ["\n".join(rows[start:start + max_rows]) for start in starts]

def _run_chunked_local_analysis(target_content: str) -> str:
    """Analyze ``target_content`` chunk by chunk with the local model and return the final report.

    The text is split into 100-row chunks; each chunk is sent with the
    previously cleaned chunk answers as context, and the concatenated answers
    are finally submitted through ``FORENSIC_MEMORY_PROMPT``.
    """
    table_header = "OFFSET (V)      PID     TID     PPID    COMM    UID     GID     EUID    EGID    CREATION TIME   File output\n"

    chunks = chunk_by_rows(target_content, max_rows=100)
    # if more than 15 chunks, keep only first 7 and last 7 for local limitations (time)
    if len(chunks) > 15:
        chunks = chunks[:7] + chunks[-7:]

    intermediate_responses = []
    for idx, chunk in enumerate(chunks):
        print(f"[INFO] Processing chunk {idx+1}/{len(chunks)}")
        chunk_prompt = FORENSIC_CHUNK_MEMORY_PROMPT.format(
            chunk_text=chunk,
            # Joining an empty list yields "", i.e. no context for the first chunk.
            context="\n".join(intermediate_responses),
        )
        print(f"[DEBUG] Chunk prompt:\n{chunk_prompt}")
        chunk_response = generate_local_response(chunk_prompt)
        if idx == 0:
            # copy the first two lines of the real input to the content (starting folder line)
            chunk_response = "\n".join(chunk.splitlines()[0:2]) + "\n" + chunk_response
        else:
            # Remove the first two lines in subsequent chunks (header + possible repeated context)
            chunk_response = "\n".join(chunk_response.splitlines()[2:])
            # remove "OFFSET (V) ..." lines if present
            chunk_response = chunk_response.replace(table_header, "")
        # Strip Markdown code fences the model may have added.
        chunk_response = chunk_response.replace("```", "").strip()
        intermediate_responses.append(chunk_response)

    combined_response = "\n".join(intermediate_responses)
    final_prompt = FORENSIC_MEMORY_PROMPT.format(
        clean_section="--not available--",
        target_content=combined_response
    )
    return generate_local_response(final_prompt)


def process_forensic_report(clean_path: str | None, target_path: Path, output_dir: Path, model: str, temperature: float, local: bool) -> None:
    """Build the forensic prompt, query the selected backend and save the report.

    Args:
        clean_path (str | None): Path of the clean reference file, if any.
        target_path (Path): File with the output of the system under analysis.
        output_dir (Path): Directory for the report; created when missing.
        model (str): Gemini model name (used only when ``local`` is false).
        temperature (float): Sampling temperature for Gemini.
        local (bool): Use the local Ollama model instead of Gemini. Prompts of
            4000 characters or more are then processed chunk by chunk, and the
            clean reference is not included in that chunked path.

    The report is written to ``<output_dir>/<target_path.stem>.md``.
    """
    local_prompt_limit = 4000

    clean_content = read_file_content(Path(clean_path)) if clean_path else None
    target_content = read_file_content(target_path)
    prompt = build_prompt(clean_content, target_content)
    print(f"[DEBUG] Generated prompt:\n{prompt}")
    output_dir.mkdir(parents=True, exist_ok=True)

    if local:
        # Name the backend that is really used; `model` only applies to Gemini.
        logging.info(f"🧠 Sending data to {OLLAMA_MODEL} via local Ollama")
        if len(prompt) >= local_prompt_limit:
            logging.warning(f"prompt exceeded limit of 4000 characters, using chunking strategy...")
            response = _run_chunked_local_analysis(target_content)
        else:
            response = generate_local_response(prompt)
    else:
        logging.info(f"🧠 Sending data to {model} via Gemini")
        response = generate_gemini_response(prompt, model=model, temperature=temperature)

    save_response_to_file(output_dir, target_path.stem, response)

# ==========================
# ARGUMENT PARSING
# ==========================

def parse_arguments():
    """Parse the command line and return the resulting ``argparse.Namespace``.

    ``--input`` and ``--case_folder`` are mandatory; ``--clean`` is optional;
    ``--model`` and ``--temperature`` have defaults; ``--debug`` and
    ``--local`` are boolean switches.
    """
    option_specs = (
        ("--clean", {"required": False, "help": "Path to the clean reference file (optional)"}),
        ("--input", {"required": True, "help": "Path to the target system output file"}),
        ("--case_folder", {"required": True, "help": "Directory of the case"}),
        ("--model", {"default": "gemini-2.5-pro", "help": "Gemini model to use"}),
        ("--temperature", {"type": float, "default": 0.3, "help": "Sampling temperature"}),
        ("--debug", {"action": "store_true", "help": "Enable debug logging"}),
        ("--local", {"action": "store_true", "help": "Enable local mode"}),
    )

    cli = argparse.ArgumentParser(description="Generate forensic report using Gemini")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

# ==========================
# MAIN
# ==========================

def main():
    """Entry point: parse the CLI, adjust logging and generate the report."""
    cli_args = parse_arguments()

    # --debug overrides the level configured at import time.
    if cli_args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    target_file = Path(cli_args.input)
    report_dir = Path(cli_args.case_folder) / OUTPUT_SUBDIR

    process_forensic_report(
        clean_path=cli_args.clean,
        target_path=target_file,
        output_dir=report_dir,
        model=cli_args.model,
        temperature=cli_args.temperature,
        local=cli_args.local,
    )

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()