# -*- coding: utf-8 -*-
"""
Grid search for univariate (single-neuron) logistic regression.

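For each hyperparameter combination (solver, max_iter, C):
  - fit one logistic regressor per neuron, using that neuron's time-averaged
    activity as the only feature, and score it with stratified K-fold
    cross-validation (the same folds are shared by all neurons and configs);
  - summarize the per-neuron CV accuracies by their median and 25%/75%
    quantiles.
The configuration with the highest median accuracy is reported, and the
per-configuration median/quantile vectors are saved as .npy files.
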
"""

# ===========================
# Paths (EDIT THESE)
# ===========================
# Raw strings (r"...") keep the Windows backslashes literal.
# tensor_path: file holding the activity tensor.
# labels_path: file holding the per-trial labels.
# Both are read by _load_mat_h5py_first (h5py first, scipy.io.loadmat fallback).
tensor_path = r"E:\Lorenzo\activity_tensor_f.h5f"
labels_path = r"C:\Users\Lorenzo\Documents\content\code\Fluorescence_Traces\data\trials_labels.mat"

# ===========================
# Imports
# ===========================
import numpy as np
import h5py

from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score


# ===========================
# Config
# ===========================
N_SPLITS = 5   # number of stratified CV folds (one fixed set shared by every neuron/config)
SEED     = 0   # passed to set_seed() and used as random_state for the CV fold shuffling

# Hyperparameter grid for univariate LR: every (solver, max_iter, C) combination
# is evaluated. Each list currently holds a single value, i.e. one configuration.
# C is sklearn LogisticRegression's inverse regularization strength (smaller = stronger).
SOLVERS_GRID   = ["lbfgs"]
MAX_ITERS_GRID = [ 2000]
C_GRID         = [0.01]


# ===========================
# Utils: loading from .mat/.h5
# ===========================
def _try_keys(h5file, candidates):
    """
    Fetch the entry for the first key in *candidates* that exists in *h5file*.

    *h5file* may be any container supporting ``in`` and ``.keys()`` (an open
    ``h5py.File`` in this script). The matching entry is returned as a NumPy
    array. If no candidate is present, a KeyError listing the candidates and
    the available keys is raised.
    """
    not_found = object()  # private sentinel: no candidate matched
    match = next((key for key in candidates if key in h5file), not_found)
    if match is not_found:
        raise KeyError(
            f"None of the candidate keys {candidates} found. "
            f"Available keys: {list(h5file.keys())}"
        )
    return np.array(h5file[match])


def _load_mat_array(path, candidates, kind):
    """
    Load the first dataset named in *candidates* from a single MAT/HDF5 file.

    Tries h5py first (HDF5 / MAT v7.3), then falls back to scipy.io.loadmat
    (older MAT versions). Arrays from the loadmat fallback are transposed
    (all axes reversed). h5py exposes MATLAB arrays with their axes reversed,
    while loadmat keeps MATLAB's order, so the transpose makes both paths
    return the same layout.

    Parameters
    ----------
    path : str
        File to read.
    candidates : list of str
        Keys to try, in order.
    kind : str
        Human-readable name ("tensor", "labels") used in messages.

    Raises
    ------
    RuntimeError
        If neither reader can load the file or none of the keys is present.
    """
    # ---- Try HDF5 (MAT v7.3) ----
    try:
        with h5py.File(path, "r") as f:
            return _try_keys(f, candidates)
    except Exception as e_h5:
        # Keep the error: the name bound by `except ... as` is deleted on exit.
        h5_error = e_h5

    print(f"h5py load failed for {kind} file, trying scipy.io.loadmat...")
    # ---- Fallback: scipy.io.loadmat (older MAT) ----
    try:
        import scipy.io as spio
        contents = spio.loadmat(path)
        for k in candidates:
            if k in contents:
                # loadmat keeps MATLAB axis order; reverse it to match h5py.
                return np.transpose(contents[k])
        raise KeyError(f"Could not find {kind} key in {list(contents.keys())}")
    except Exception as e_mat:
        raise RuntimeError(
            f"Failed to load {kind} MAT file {path!r}.\n"
            f"HDF5 error: {h5_error}\nloadmat error: {e_mat}"
        ) from e_mat


def _load_mat_h5py_first(tensor_path, labels_path):
    """
    Robust loader: for each file, try h5py (e.g. MAT v7.3); if it fails,
    fall back to scipy.io.loadmat.

    The two files are loaded independently. This means a v7.3 tensor file
    can be combined with an older-format labels file, and vice versa.
    Whichever reader succeeds, the returned arrays use the h5py axis order,
    i.e. MATLAB's dimensions reversed.

    Expected keys:
      - Tensor: 'activity_tensor' or 'tensor' or 'X'
      - Labels: 'labels' or 'y' or 'trials_labels'

    Returns
    -------
    (Xraw, yraw) : tuple of np.ndarray

    Raises
    ------
    RuntimeError
        If either file cannot be loaded or lacks all of its candidate keys.
    """
    Xraw = _load_mat_array(tensor_path, ["activity_tensor", "tensor", "X"], "tensor")
    yraw = _load_mat_array(labels_path, ["labels", "y", "trials_labels"], "labels")
    return Xraw, yraw


# ===========================
# Small helpers
# ===========================
def set_seed(seed=0):
    """
    Seed Python's ``random`` module and NumPy's global RNG with *seed*.

    If PyTorch can be imported, its CPU and CUDA generators are seeded too.
    Any error raised while seeding torch is ignored on purpose, so seeding
    stays best-effort on machines without a working torch/CUDA setup.
    """
    import random
    try:
        import torch as _torch
    except ImportError:
        _torch = None

    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    if _torch is None:
        return
    try:
        _torch.manual_seed(seed)
        _torch.cuda.manual_seed_all(seed)
    except Exception:
        pass


def time_average(X):
    """Collapse the last (time) axis of a 3D array by taking its mean.

    X: (trials, neurons, time) -> returns (trials, neurons)
    Raises ValueError when X is not 3D.
    """
    if X.ndim == 3:
        return np.mean(X, axis=-1)
    raise ValueError(f"X must be 3D (trials, neurons, time), got shape {X.shape}")


def make_cv_splits(y, n_splits=5, seed=0):
    """
    Materialize shuffled StratifiedKFold folds once, as a list of
    (train_indices, test_indices) pairs.

    Computing the folds up front lets every model reuse exactly the same
    partition. Stratification only looks at *y*, so a zero-filled dummy
    feature matrix with one row per sample is passed to ``split``.
    """
    n_samples = len(y)
    placeholder = np.zeros((n_samples, 1), dtype=float)
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    return [(train, test) for train, test in splitter.split(placeholder, y)]


# ===========================
# Univariate LR helper
# ===========================
def cv_single_neuron_scores(Xavg, y, splits, solver, max_iter, C):
    """
    Cross-validate one logistic regressor per neuron. Each regressor uses only
    that neuron's column of `Xavg` as its single input feature.

    Parameters
    ----------
    Xavg : 2D array (trials, neurons) of time-averaged activity.
    y : per-trial class labels, indexable by the fold index arrays.
    splits : sequence of (train_idx, test_idx) pairs, reused for every neuron.
    solver, max_iter, C : forwarded to sklearn's LogisticRegression.

    Returns
    -------
    acc_per_neuron : (n_neurons,) float array; entry i is the mean test-fold
        accuracy of the model trained on neuron i alone.
    """
    _, n_neurons = Xavg.shape
    report_every = max(1, n_neurons // 20)

    def _fold_accuracy(col, train_idx, test_idx):
        # A fresh scaler + classifier per fold keeps the folds independent.
        model = Pipeline([
            ("scaler", StandardScaler()),
            ("clf", LogisticRegression(C=C, max_iter=max_iter, solver=solver)),
        ])
        # col:col+1 keeps a 2D (n_samples, 1) feature matrix.
        feature = slice(col, col + 1)
        model.fit(Xavg[train_idx, feature], y[train_idx])
        predicted = model.predict(Xavg[test_idx, feature])
        return accuracy_score(y[test_idx], predicted)

    acc_per_neuron = np.zeros(n_neurons, dtype=float)
    for neuron in range(n_neurons):
        acc_per_neuron[neuron] = np.mean(
            [_fold_accuracy(neuron, tr, te) for tr, te in splits]
        )
        if (neuron + 1) % report_every == 0:
            print(f"  [{solver}, it={max_iter}, C={C}] single-neuron CV {neuron+1}/{n_neurons} done...")

    return acc_per_neuron


# ===========================
# MAIN
# ===========================
if __name__ == "__main__":
    import itertools

    set_seed(SEED)

    print("Loading tensor and labels from MAT/HDF5 files...")
    Xraw, yraw = _load_mat_h5py_first(tensor_path, labels_path)

    # Expect Xraw shape (Tr, T, N): trials × time × neurons
    if Xraw.ndim != 3:
        raise ValueError(f"Activity tensor must be 3D; got shape {Xraw.shape}")

    # (Tr, T, N) -> (Tr, N, T) = (trials, neurons, time): swap the time and
    # neuron axes, keeping trials first.
    X = np.transpose(Xraw, (0, 2, 1)).astype(np.float32)

    # Labels: flatten and encode to integers 0..n_classes-1
    le = LabelEncoder()
    y = le.fit_transform(np.asarray(yraw).ravel()).astype(int)

    # Fail fast when trials and labels disagree. A mismatch may mean the
    # tensor axes are not in the expected (Tr, T, N) order.
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            f"Number of trials in the tensor ({X.shape[0]}) does not match "
            f"the number of labels ({y.shape[0]}); check the tensor axis order."
        )

    print(f"Loaded tensor: X.shape={X.shape} (trials, neurons, time)")
    print(f"Loaded labels: y.shape={y.shape}, classes={np.unique(y)}")

    # Prepare CV splits (shared by every neuron/config) and time-averaged features
    splits = make_cv_splits(y, n_splits=N_SPLITS, seed=SEED)
    Xavg = time_average(X)  # (trials, neurons)
    print(f"\nTime-averaged features: Xavg.shape={Xavg.shape}")

    # ===========================
    # GRID SEARCH (UNIVARIATE)
    # ===========================
    results = []  # one dict per evaluated config: hyperparameters + summary stats

    print("\n=== Grid search on UNIVARIATE logistic regressors (neuron by neuron) ===")
    # product() iterates in the same order as nested solver/max_iter/C loops.
    for solver, max_iter, C in itertools.product(SOLVERS_GRID, MAX_ITERS_GRID, C_GRID):
        print(f"\n>>> Evaluating config: solver={solver}, max_iter={max_iter}, C={C}")
        acc_single = cv_single_neuron_scores(
            Xavg, y, splits,
            solver=solver,
            max_iter=max_iter,
            C=C
        )

        # Distribution of CV accuracy over neurons
        q25_single, median_single, q75_single = np.quantile(acc_single, [0.25, 0.5, 0.75])

        results.append({
            "solver": solver,
            "max_iter": max_iter,
            "C": C,
            "median": median_single,
            "q25": q25_single,
            "q75": q75_single
        })

        print(f"   -> median single-neuron acc = {median_single*100:.2f}%")
        print(f"   -> 25%   single-neuron acc = {q25_single*100:.2f}%")
        print(f"   -> 75%   single-neuron acc = {q75_single*100:.2f}%")

    if not results:
        raise ValueError("Hyperparameter grid is empty: no model was evaluated.")

    median_scores = np.array([r["median"] for r in results])
    q25_scores    = np.array([r["q25"] for r in results])
    q75_scores    = np.array([r["q75"] for r in results])

    # ===========================
    # PICK BEST MODEL (by median)
    # ===========================
    best_idx = int(np.argmax(median_scores))
    best_cfg = results[best_idx]

    print("\n=== FINAL GRID SEARCH RESULT (UNIVARIATE, median-based) ===")
    print(f"Number of models tested: {len(results)}")
    print("Best model index:", best_idx)
    print("Best config:")
    print(best_cfg)

    print("\nVector of median_scores (per model):")
    print(median_scores)

    print("\nVector of q25_scores (per model):")
    print(q25_scores)

    print("\nVector of q75_scores (per model):")
    print(q75_scores)

    # Save per-config summary vectors (same order as `results`)
    np.save("univar_lr_median_scores.npy", median_scores)
    np.save("univar_lr_q25_scores.npy", q25_scores)
    np.save("univar_lr_q75_scores.npy", q75_scores)

