# -*- coding: utf-8 -*-
"""
Created on Sat Nov 15 21:30:28 2025

@author: Lorenzo
"""

# -*- coding: utf-8 -*-
"""
Grid search over baseline logistic-regression hyperparameters: neurons are ranked by a precomputed per-neuron score, and CV accuracy is tracked as the top-k neurons are added.

"""

# ===========================
# Paths (EDIT THESE)
# ===========================
# Activity tensor file; read by _load_mat_h5py_first (h5py first, then
# scipy.io.loadmat as a fallback).
tensor_path  = r"E:\Lorenzo\activity_tensor_f.h5f"
# Trial labels file; read the same way, then flattened and label-encoded in MAIN.
labels_path  = r"C:\Users\Lorenzo\Documents\content\code\Fluorescence_Traces\data\trials_labels.mat"
scores_path  = r"E:\Lorenzo\univar_best_scores.npy"   # <- .npy with one score per neuron (length must equal n_neurons)

# ===========================
# Imports
# ===========================
import numpy as np
import h5py
import matplotlib.pyplot as plt

from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# ===========================
# Config
# ===========================
# Number of stratified CV folds; the folds are built once and shared by every configuration.
N_SPLITS      = 5
# Passed to set_seed and used as the shuffle seed of the CV folds.
SEED          = 0

# Progressive k settings
K_MAX_NEURONS = 1400      # max # of neurons to consider (clipped to n_neurons inside progressive_baseline_curve)
K_STEP        = 10        # spacing between successive k values (e.g. 1, 5, 10); the k grid starts at 10

# "100% accuracy" target and tolerance
ACC_TARGET    = 1.0
ACC_TOL       = 0     # a k counts as "perfect" when mean CV accuracy >= ACC_TARGET - ACC_TOL
N_K_AFTER_PERFECT = 10   # after the first "perfect" k, evaluate this many further k values, then stop

# Hyperparameter grid for baseline LR: every (solver, max_iter, C) combination is evaluated.
SOLVERS_GRID   = ["lbfgs"]
MAX_ITERS_GRID = [2000]
C_GRID         = [0.01]   # passed as LogisticRegression(C=...): inverse regularization strength


# ===========================
# Utils: loading from .mat/.h5
# ===========================
def _try_keys(h5file, candidates):
    """Return first matching dataset for any key in candidates."""
    for k in candidates:
        if k in h5file:
            return np.array(h5file[k])
    raise KeyError(
        f"None of the candidate keys {candidates} found. "
        f"Available keys: {list(h5file.keys())}"
    )


def _load_mat_h5py_first(tensor_path, labels_path):
    """
    Robust loader: try h5py (e.g. MAT v7.3). If it fails, try scipy.io.loadmat.

    Expected keys:
      - Tensor: 'activity_tensor' or 'tensor' or 'X'
      - Labels: 'labels' or 'y' or 'trials_labels'
    """
    X = y = None

    # ---- Try HDF5 (MAT v7.3) ----
    try:
        with h5py.File(tensor_path, "r") as f:
            Xraw = _try_keys(f, ["activity_tensor", "tensor", "X"])
        with h5py.File(labels_path, "r") as f:
            yraw = _try_keys(f, ["labels", "y", "trials_labels"])
        return Xraw, yraw
    except Exception as e_h5:
        print("h5py load failed, trying scipy.io.loadmat...")
        # ---- Fallback: scipy.io.loadmat (older MAT) ----
        try:
            import scipy.io as spio
            Xd = spio.loadmat(tensor_path)
            yd = spio.loadmat(labels_path)

            for k in ["activity_tensor", "tensor", "X"]:
                if k in Xd:
                    X = Xd[k]
                    break
            if X is None:
                raise KeyError(f"Could not find tensor key in {list(Xd.keys())}")

            for k in ["labels", "y", "trials_labels"]:
                if k in yd:
                    y = yd[k]
                    break
            if y is None:
                raise KeyError(f"Could not find labels key in {list(yd.keys())}")

            return X, y
        except Exception as e_mat:
            raise RuntimeError(
                "Failed to load MAT files.\n"
                f"HDF5 error: {e_h5}\nloadmat error: {e_mat}"
            )


# ===========================
# Small helpers
# ===========================
def set_seed(seed=0):
    """Seed Python's ``random`` module, NumPy's global RNG and, if present, torch.

    torch is optional: when it cannot be imported it is skipped. Any exception
    raised while seeding torch (CPU or CUDA) is deliberately swallowed, so
    seeding never aborts the script.
    """
    import random

    try:
        import torch
    except ImportError:
        torch = None

    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    if torch is None:
        return
    try:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    except Exception:
        pass  # best-effort only: torch seeding must not stop the run


def time_average(X):
    """Collapse the last (time) axis of a 3-D activity array by its mean.

    Parameters
    ----------
    X : array of shape (trials, neurons, time)

    Returns
    -------
    array of shape (trials, neurons)

    Raises
    ------
    ValueError
        If `X` is not 3-D.
    """
    if X.ndim == 3:
        return X.mean(axis=2)
    raise ValueError(f"X must be 3D (trials, neurons, time), got shape {X.shape}")


def make_cv_splits(y, n_splits=5, seed=0):
    """Materialize one fixed list of shuffled, stratified CV folds.

    Computing the folds once and reusing the list guarantees that every
    model configuration is evaluated on exactly the same train/test splits.

    Parameters
    ----------
    y : array-like of shape (trials,)
        Class labels used for stratification.
    n_splits : int
        Number of folds.
    seed : int
        ``random_state`` of the shuffling.

    Returns
    -------
    list of (train_idx, test_idx)
    """
    # StratifiedKFold derives the folds from y alone; X only needs the right
    # number of rows, so a one-column zero matrix stands in for it.
    placeholder = np.zeros((len(y), 1), dtype=float)
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    folds = splitter.split(placeholder, y)
    return list(folds)


# ===========================
# Baseline LR: progressive k
# ===========================
def cv_logistic_subset(Xavg, y, splits, neuron_idx, solver, max_iter, C):
    """
    CV accuracy of a standardized logistic regression on a subset of neurons.

    For every (train, test) split, a StandardScaler -> LogisticRegression
    pipeline is fit on the training trials and scored on the test trials,
    using only the columns of `Xavg` listed in `neuron_idx`.

    Parameters
    ----------
    Xavg : (trials, neurons)
    y    : (trials,)
    splits : iterable of (train_idx, test_idx)
    neuron_idx : 1D array of neuron indices to keep
    solver, max_iter, C : LR hyperparameters

    Returns
    -------
    accs : array of shape (n_splits,)
        Accuracy per outer fold.
    """
    # Every fold uses the same neuron columns, so select them once up front.
    # This avoids copying full-width rows (all neurons) on every fold.
    X_sub = Xavg[:, neuron_idx]

    # The same pipeline object is reused; each .fit() call re-fits both steps.
    pipe = Pipeline([
        ("scaler", StandardScaler()),
        ("clf", LogisticRegression(
            C=C,
            max_iter=max_iter,
            solver=solver
        ))
    ])
    accs = []
    for tr_idx, te_idx in splits:
        pipe.fit(X_sub[tr_idx], y[tr_idx])
        pred = pipe.predict(X_sub[te_idx])
        accs.append(accuracy_score(y[te_idx], pred))
    return np.array(accs)


def progressive_baseline_curve(
    Xavg, y, splits, neuron_order, solver, max_iter, C,
    k_max, k_step, acc_target=1.0, acc_tol=1e-3,
    n_k_after_perfect=10, k_start=10
):
    """
    Compute the accuracy-vs-k curve for one hyperparameter combination.

    Also finds the smallest evaluated k at which mean CV accuracy reaches ~100%.
    For each k, the first k entries of `neuron_order` are used as features.

    Early stopping: once ~100% accuracy is first reached at some k*,
    continue evaluating only the next n_k_after_perfect k values,
    then stop.

    Parameters
    ----------
    k_max : int
        Largest k; clipped to the number of neurons. It is always evaluated,
        even when the step grid does not land on it.
    k_step : int
        Spacing between successive k values (>= 1).
    k_start : int, optional
        First k evaluated (default 10, the value that used to be hard-coded).
        Clamped to at least 1. Use e.g. ``k_start=k_step`` to include small k.

    Returns
    -------
    ks : array of evaluated k values
    acc_means : mean CV accuracies for each evaluated k
    acc_stds  : std (ddof=1) of CV accuracies for each evaluated k
                (0.0 when there is only one fold)
    first_k_perfect : smallest evaluated k such that mean_acc >= acc_target - acc_tol,
                      or None if never reached.

    Raises
    ------
    ValueError
        If k_max < 1 after clipping, k_step < 1, or neuron_order has fewer
        than k_max entries.
    """
    n_neurons = Xavg.shape[1]
    k_max = min(k_max, n_neurons)
    if k_max < 1:
        raise ValueError(f"k_max must be >= 1 after clipping to n_neurons; got {k_max}")
    if k_step < 1:
        raise ValueError(f"k_step must be >= 1; got {k_step}")
    if len(neuron_order) < k_max:
        raise ValueError(
            f"neuron_order has {len(neuron_order)} entries, fewer than k_max={k_max}"
        )
    k_start = max(1, int(k_start))

    full_ks = np.arange(k_start, k_max + 1, k_step, dtype=int)
    # Always include the full budget k_max (the grid is ascending and <= k_max).
    if full_ks.size == 0 or full_ks[-1] != k_max:
        full_ks = np.append(full_ks, k_max)

    acc_means = []
    acc_stds  = []
    first_k_perfect = None
    idx_first_perfect = None

    for idx, k in enumerate(full_ks):
        subset = neuron_order[:k]
        accs_k = cv_logistic_subset(
            Xavg, y, splits,
            neuron_idx=subset,
            solver=solver,
            max_iter=max_iter,
            C=C
        )
        mean_k = accs_k.mean()
        # Sample std across folds; guard the single-fold case (ddof=1 -> NaN).
        std_k  = accs_k.std(ddof=1) if accs_k.size > 1 else 0.0
        acc_means.append(mean_k)
        acc_stds.append(std_k)

        print(f"    k={k:4d} | mean acc={mean_k*100:.2f}% (std={std_k*100:.2f}%)")

        # Check first time we hit ~100% accuracy
        if (first_k_perfect is None) and (mean_k >= acc_target - acc_tol):
            first_k_perfect = k
            idx_first_perfect = idx
            print(f"      -> first ~100% accuracy at k={k}")

        # If we've already hit perfect and evaluated n_k_after_perfect more ks, stop
        if (idx_first_perfect is not None) and (idx >= idx_first_perfect + n_k_after_perfect):
            print(f"      -> stopping early after {n_k_after_perfect} ks beyond k={first_k_perfect}")
            break

    # Truncate ks to the entries that were actually evaluated
    ks = full_ks[:len(acc_means)]

    return ks, np.array(acc_means), np.array(acc_stds), first_k_perfect


# ===========================
# MAIN
# ===========================
if __name__ == "__main__":
    set_seed(SEED)

    # ----- Load tensor, labels -----
    print("Loading tensor and labels from MAT/HDF5 files...")
    Xraw, yraw = _load_mat_h5py_first(tensor_path, labels_path)

    # Expect Xraw shape (Tr, T, N): trials × time × neurons
    # NOTE(review): h5py and scipy.io.loadmat return MATLAB arrays with
    # opposite axis orders, so this layout presumably holds for only one of
    # the two load paths. Confirm it for your files.
    if Xraw.ndim != 3:
        raise ValueError(f"Activity tensor must be 3D; got shape {Xraw.shape}")

    # (Tr, T, N) -> (Tr, N, T) = (trials, neurons, time).
    # A single transpose, equivalent to transpose(2, 1, 0) followed by moveaxis(-1, 0).
    X = np.transpose(Xraw, (0, 2, 1)).astype(np.float32)

    # Labels: flatten and encode to 0..C-1 (`le` kept so classes can be mapped back)
    le = LabelEncoder()
    y = le.fit_transform(np.asarray(yraw).ravel()).astype(int)

    # Fail early if labels and trials disagree, instead of mis-indexing in CV.
    if y.shape[0] != X.shape[0]:
        raise ValueError(
            f"Number of labels ({y.shape[0]}) does not match "
            f"number of trials ({X.shape[0]})"
        )

    print(f"Loaded tensor: X.shape={X.shape} (trials, neurons, time)")
    print(f"Loaded labels: y.shape={y.shape}, classes={np.unique(y)}")

    # Time-averaged features and CV splits (shared by every configuration)
    splits = make_cv_splits(y, n_splits=N_SPLITS, seed=SEED)
    Xavg = time_average(X)  # (trials, neurons)
    n_trials, n_neurons = Xavg.shape
    print(f"\nTime-averaged features: Xavg.shape={Xavg.shape}")

    # ----- Load neuron scores and build ranking -----
    print(f"\nLoading neuron scores from: {scores_path}")
    # Flatten, so (n_neurons, 1) or (1, n_neurons) score files also work.
    scores = np.asarray(np.load(scores_path), dtype=float).ravel()
    if scores.size != n_neurons:
        raise ValueError(
            f"Scores length {scores.size} does not match n_neurons={n_neurons}"
        )

    # Higher score = better neuron (assumption). np.argsort is ascending, so
    # sort the negated scores to get a best-first order. The stable sort keeps
    # tied neurons in index order, and NaN scores end up last.
    # If your scores are lower-is-better (e.g. p-values), drop the minus sign.
    neuron_order = np.argsort(-scores, kind="stable")   # best-first

    # ----- Grid search over LR hyperparameters -----
    results = []
    best_idx = None
    best_k_perfect = np.inf  # we want to MINIMIZE this

    print("\n=== Grid search for baseline LR (progressive k) ===")
    for solver in SOLVERS_GRID:
        for max_iter in MAX_ITERS_GRID:
            for C in C_GRID:
                print(f"\n>>> Config: solver={solver}, max_iter={max_iter}, C={C}")
                ks, acc_means, acc_stds, first_k_perfect = progressive_baseline_curve(
                    Xavg, y, splits,
                    neuron_order=neuron_order,
                    solver=solver,
                    max_iter=max_iter,
                    C=C,
                    k_max=K_MAX_NEURONS,
                    k_step=K_STEP,
                    acc_target=ACC_TARGET,
                    acc_tol=ACC_TOL,
                    n_k_after_perfect=N_K_AFTER_PERFECT
                )

                if first_k_perfect is None:
                    # never reached 100% accuracy
                    k_score = np.inf
                    print("    -> never reached ~100% accuracy.")
                else:
                    k_score = first_k_perfect
                    print(f"    -> first k with ~100% accuracy: k={first_k_perfect}")

                cfg_result = {
                    "solver": solver,
                    "max_iter": max_iter,
                    "C": C,
                    "ks": ks,
                    "acc_means": acc_means,
                    "acc_stds": acc_stds,
                    "first_k_perfect": first_k_perfect
                }
                results.append(cfg_result)

                # Strict '<': on ties the earlier configuration is kept.
                if k_score < best_k_perfect:
                    best_k_perfect = k_score
                    best_idx = len(results) - 1

    # ----- Pick best model -----
    if best_idx is None:
        print("\nNo configuration reached ~100% accuracy within k <= K_MAX_NEURONS.")
        print("You may want to increase K_MAX_NEURONS or relax ACC_TOL.")
    else:
        best_cfg = results[best_idx]
        print("\n=== BEST BASELINE CONFIGURATION (by smallest k reaching ~100% acc) ===")
        print(f"Index: {best_idx}")
        print(f"solver   = {best_cfg['solver']}")
        print(f"max_iter = {best_cfg['max_iter']}")
        print(f"C        = {best_cfg['C']}")
        print(f"first_k_perfect = {best_cfg['first_k_perfect']}")

        # ----- Plot accuracy vs k for the best model -----
        ks_best        = best_cfg["ks"]
        acc_means_best = best_cfg["acc_means"]
        acc_stds_best  = best_cfg["acc_stds"]

        plt.figure(figsize=(7.4, 5.6))
        plt.errorbar(
            ks_best,
            acc_means_best * 100,
            yerr=acc_stds_best * 100,
            fmt="-o",
            capsize=4,
            label="Time-avg LR (CV)"
        )
        plt.axhline(100.0, color="gray", linestyle="--", label="100% accuracy")
        plt.xlabel("# neurons used (k)")
        # Label follows the configured number of folds instead of a hard-coded "5".
        plt.ylabel(f"{N_SPLITS}-fold CV accuracy (%)")
        plt.title(
            f"Baseline LR accuracy vs k\n"
            f"solver={best_cfg['solver']}, max_iter={best_cfg['max_iter']}, C={best_cfg['C']}"
        )
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.tight_layout()
        plt.show()
