# -*- coding: utf-8 -*-
"""
Created on Sat Nov 15 22:49:12 2025

@author: Lorenzo
"""

# -*- coding: utf-8 -*-
"""
Created on Sat Nov 15 22:21:29 2025

@author: Lorenzo
"""

# -*- coding: utf-8 -*-
"""
Hyperparameter scan for CP-logistic regression with progressively more WORST neurons.

"""

# ===========================
# Paths (EDIT THESE)
# ===========================
# Activity tensor file; read by _load_mat_h5py_first, which looks for the
# dataset under 'activity_tensor', 'tensor' or 'X'.
tensor_path = r"E:\Lorenzo\activity_tensor_f.h5f"
# Trial labels file; read by _load_mat_h5py_first, which looks for the
# dataset under 'labels', 'y' or 'trials_labels'.
labels_path = r"C:\Users\Lorenzo\Documents\content\code\Fluorescence_Traces\data\trials_labels.mat"
# Loaded with np.load in the main section; its length must equal the number of
# neurons, and an ascending sort of it defines the WORST-first neuron order.
scores_path = r"E:\Lorenzo\univar_best_scores.npy"   # per-neuron scores (e.g. univariate accuracies)

# ===========================
# Imports (in use)
# ===========================
import numpy as np
import h5py
import matplotlib.pyplot as plt

from dataclasses import dataclass

from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn

# ===========================
# Global config
# ===========================
N_SPLITS_OUTER   = 5      # number of outer StratifiedKFold folds (passed to make_cv_splits)
SEED             = 0      # shared seed: global RNGs, outer CV shuffling, inner split, CP training

K_MAX            = 1400   # max #neurons to consider (will be clipped to n_neurons)
K_STEP           = 10     # step size in k (e.g. 1 -> 1,2,3,...; 5 -> 5,10,15,...)

INNER_VAL_FRAC   = 0.2    # test_size of the inner train_test_split used for early-stopping validation

# Hyperparameter grids for CP; the main section evaluates every combination.
R_grid      = [5]          # candidate ranks
WD_grid     = [1e-3]     # weight_decay values
LR_grid     = [1e-3]     # learning rates
MAXE_grid   = [2500]           # max_epochs values
PAT_grid    = [30]             # patience values

# Use the GPU when torch reports CUDA as available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ===========================
# Utils: loading from .mat/h5
# ===========================
def _try_keys(h5file, candidates):
    """Return first matching dataset for any key in candidates."""
    for k in candidates:
        if k in h5file:
            return np.array(h5file[k])
    raise KeyError(f"None of the candidate keys {candidates} found. "
                   f"Available keys: {list(h5file.keys())}")


def _load_one_array(path, candidates, what):
    """Load a single array from ``path``: h5py first, scipy.io.loadmat second.

    ``candidates`` are the dataset/variable names to try, in order; ``what``
    ('tensor' or 'labels') is only used in the error message.
    """
    # ---- Try HDF5 (MAT v7.3) ----
    try:
        with h5py.File(path, "r") as f:
            return _try_keys(f, candidates)
    except Exception as e_h5:
        print("h5py load failed, trying scipy.io.loadmat...")
        # ---- Fallback: scipy.io.loadmat (older MAT) ----
        try:
            import scipy.io as spio
            contents = spio.loadmat(path)
            for k in candidates:
                if k in contents:
                    return contents[k]
            raise KeyError(f"Could not find {what} key in {list(contents.keys())}")
        except Exception as e_mat:
            # Report both failures and keep the loadmat traceback attached.
            raise RuntimeError(
                f"Failed to load MAT files.\n"
                f"HDF5 error: {e_h5}\nloadmat error: {e_mat}"
            ) from e_mat


def _load_mat_h5py_first(tensor_path, labels_path):
    """
    Robust loader: try h5py (MAT v7.3). If it fails, try scipy.io.loadmat.

    Each file is loaded independently, so the tensor and the labels may be
    stored in different formats (e.g. an HDF5 tensor next to a pre-7.3 MAT
    labels file). When both files share a format the result is the same as
    loading both with that format's reader.

    Expected keys:
      - Tensor: 'activity_tensor' or 'tensor' or 'X'
      - Labels: 'labels' or 'y' or 'trials_labels'

    Returns
    -------
    (X, y) : the raw arrays exactly as returned by the reader that succeeded;
             no axis reordering or flattening is done here.

    Raises
    ------
    RuntimeError
        If a file can be read by neither h5py nor scipy.io.loadmat, or holds
        none of the expected keys. The message contains both underlying errors.
    """
    Xraw = _load_one_array(tensor_path, ["activity_tensor", "tensor", "X"], "tensor")
    yraw = _load_one_array(labels_path, ["labels", "y", "trials_labels"], "labels")
    return Xraw, yraw

# ===========================
# Small helpers
# ===========================
def set_seed(seed=0):
    """Seed Python's ``random``, NumPy and torch (CPU and all CUDA devices).

    Seeding torch is best-effort: any exception raised there is ignored.
    """
    import random
    from contextlib import suppress

    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    with suppress(Exception):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def make_cv_splits(y, n_splits=5, seed=0):
    """Materialise shuffled StratifiedKFold splits once so they can be reused.

    Returns a list of ``(train_indices, test_indices)`` pairs. Only ``y``
    drives the stratification; the feature matrix handed to the splitter is
    an all-zero placeholder with one row per label.
    """
    placeholder = np.zeros((len(y), 1), dtype=float)
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    return [pair for pair in splitter.split(placeholder, y)]


def time_average(X):
    """Collapse the last (time) axis of a 3D tensor by taking its mean.

    X: (trials, neurons, time) -> (trials, neurons)

    Raises ValueError when X is not three-dimensional.
    """
    if X.ndim == 3:
        return X.mean(axis=2)
    raise ValueError(f"X must be 3D (trials, neurons, time), got shape {X.shape}")

# ===========================
# Standardization for tensors
# ===========================
def fit_scaler_on_tensor(X_train):
    """
    Fit and return a StandardScaler for a 3D tensor (trials, neurons, time).

    Each trial is flattened to one row of length neurons*time, so every
    (neuron, time) entry gets its own mean and scale.
    """
    n_trials, n_neurons, n_time = X_train.shape
    flattened = X_train.reshape(n_trials, n_neurons * n_time)
    return StandardScaler().fit(flattened)


def transform_tensor_with_scaler(X, scaler):
    """
    Standardize a 3D tensor with an already-fitted scaler.

    X is flattened to (trials, neurons*time), passed through
    ``scaler.transform`` and reshaped back to (trials, neurons, time).
    """
    n_trials, n_neurons, n_time = X.shape
    standardized = scaler.transform(X.reshape(n_trials, n_neurons * n_time))
    return standardized.reshape(n_trials, n_neurons, n_time)

# ===========================
# CP-logistic model
# ===========================
class CPLogisticRegression(nn.Module):
    """
    Logistic classifier whose weights are constrained to a rank-R CP form.

    Per trial n:
      - X_n is an (I, J) neuron x time matrix
      - A (I x R) holds the neuron factors, B (J x R) the time factors
      - feature r is z_n,r = a_r^T X_n b_r
      - logits are Z @ W_class + bias, with W_class of shape (R, C)
    """

    def __init__(self, I, J, R, C):
        super().__init__()
        init_scale = 0.1
        # Random draws happen in the order A, B, W_class.
        self.A = nn.Parameter(init_scale * torch.randn(I, R))
        self.B = nn.Parameter(init_scale * torch.randn(J, R))
        self.W_class = nn.Parameter(init_scale * torch.randn(R, C))
        self.bias = nn.Parameter(torch.zeros(C))

    def features(self, X):
        """
        Map trials X of shape (N, I, J) to rank features of shape (N, R).

        The time axis is contracted with B first, giving (N, I, R); that
        result is weighted by A and summed over the neuron axis.
        """
        time_projected = torch.tensordot(X, self.B, dims=([2], [0]))  # (N, I, R)
        neuron_weighted = time_projected * self.A.unsqueeze(0)         # (N, I, R)
        return neuron_weighted.sum(dim=1)                              # (N, R)

    def forward(self, X):
        """Return class logits of shape (N, C) for trials X of shape (N, I, J)."""
        return torch.matmul(self.features(X), self.W_class) + self.bias


def _cp_gauge_fix(model: CPLogisticRegression):
    """
    Normalize columns of A and B, absorb norms into W_class for readability
    and numerical stability.
    """
    with torch.no_grad():
        An = torch.linalg.norm(model.A, dim=0).clamp_min(1e-8)
        Bn = torch.linalg.norm(model.B, dim=0).clamp_min(1e-8)
        scale = An * Bn
        model.A /= An
        model.B /= Bn
        model.W_class *= scale.unsqueeze(1)


@dataclass
class CPTrainConfig:
    """Training settings read by ``train_cp_model``."""
    # Number of CP components; passed as R to CPLogisticRegression.
    rank: int = 3
    # Adam learning rate.
    lr: float = 1e-3
    # Passed to Adam as its weight_decay argument.
    weight_decay: float = 5e-4
    # Number of trials per minibatch slice.
    batch_size: int = 64
    # Upper bound on the number of training epochs.
    max_epochs: int = 300
    # Consecutive epochs without validation-loss improvement before stopping.
    patience: int = 30
    # Passed to set_seed at the start of every training run.
    seed: int = 0


def train_cp_model(X_train, y_train, X_val, y_val, cfg: CPTrainConfig, verbose=False):
    """
    Fit a CP-logistic model with minibatch Adam and validation-loss early stopping.

    X_train, X_val: (N, I, J) arrays, expected to be standardized already.
    y_train, y_val: 1D label arrays; the number of classes is the number of
                    distinct values in y_train.

    After each epoch the cross-entropy on the validation set is computed. The
    weights of the best epoch are kept, and training ends once ``cfg.patience``
    consecutive epochs fail to improve on it by more than 1e-7.

    Returns
    -------
    (model, val_acc): the model restored to its best weights, in eval mode,
                      and its accuracy on the validation set as a float.
    """
    set_seed(cfg.seed)

    n_neurons = X_train.shape[1]
    n_timepoints = X_train.shape[2]
    n_classes = len(np.unique(y_train))

    model = CPLogisticRegression(n_neurons, n_timepoints, cfg.rank, n_classes).to(DEVICE)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    def on_device(array, dtype):
        # Copy a NumPy array onto the training device with the given dtype.
        return torch.tensor(array, dtype=dtype, device=DEVICE)

    train_x = on_device(X_train, torch.float32)
    train_y = on_device(y_train, torch.long)
    val_x = on_device(X_val, torch.float32)
    val_y = on_device(y_val, torch.long)

    lowest_val_loss = float("inf")
    best_weights = None
    stale_epochs = 0

    # Shuffled in place at the start of every epoch (NumPy global RNG).
    order = np.arange(len(train_x))

    for epoch in range(1, cfg.max_epochs + 1):
        model.train()
        np.random.shuffle(order)
        for start in range(0, len(order), cfg.batch_size):
            batch = order[start:start + cfg.batch_size]
            optimizer.zero_grad()
            batch_loss = criterion(model(train_x[batch]), train_y[batch])
            batch_loss.backward()
            optimizer.step()
            # Re-normalize the CP factors after every optimizer step.
            _cp_gauge_fix(model)

        # validation
        model.eval()
        with torch.no_grad():
            val_loss = criterion(model(val_x), val_y).item()

        if verbose and (epoch == 1 or epoch % 10 == 0):
            print(f"  epoch {epoch:03d}  val_loss={val_loss:.4f}")

        if val_loss + 1e-7 < lowest_val_loss:
            # New best epoch: snapshot the weights on the CPU.
            lowest_val_loss = val_loss
            best_weights = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            stale_epochs = 0
            continue

        stale_epochs += 1
        if stale_epochs >= cfg.patience:
            break

    if best_weights is not None:
        model.load_state_dict(
            {name: tensor.to(DEVICE) for name, tensor in best_weights.items()}
        )

    model.eval()
    with torch.no_grad():
        val_pred = model(val_x).argmax(dim=1).cpu().numpy()
    val_acc = (val_pred == y_val).mean()

    return model, float(val_acc)


def predict_cp(model, X):
    """Return the arg-max class index per trial as a NumPy array.

    X: (N, I, J), standardized. It is copied to DEVICE as float32, and the
    forward pass runs without gradient tracking.
    """
    inputs = torch.tensor(X, dtype=torch.float32, device=DEVICE)
    with torch.no_grad():
        return model(inputs).argmax(dim=1).cpu().numpy()


# ===========================
# CV for fixed CP config over k
# ===========================
def cp_cv_for_fixed_cfg_over_k(
    X, y, worst_order, splits, ks,
    cp_cfg: CPTrainConfig,
    inner_val_frac=INNER_VAL_FRAC,
    seed=SEED,
    verbose=False
):
    """
    For a fixed CP configuration (rank, weight_decay, lr, max_epochs, patience),
    evaluate CP-logistic performance vs k (#neurons, WORST-first).

    For every k the first k entries of ``worst_order`` select the neurons.
    Within each outer fold the training trials are split once more
    (stratified, ``inner_val_frac`` held out) into an inner-training part and
    a validation part. The scaler and the model are fitted on the
    inner-training part ONLY; the validation part is used solely for early
    stopping, so it has to stay unseen by both. The outer test fold is
    transformed with the same scaler and used for the reported accuracy.

    Parameters
    ----------
    X : (trials, neurons, time) array
    y : (trials,) integer labels
    worst_order : neuron indices, worst first
    splits : iterable of (train_idx, test_idx) pairs (outer CV)
    ks : iterable of neuron counts to evaluate
    cp_cfg : CPTrainConfig used for every fit
    inner_val_frac : fraction of each outer training fold held out for validation
    seed : random_state of the inner split
    verbose : print one summary line per k

    Returns
    -------
    cp_means : (len(ks),)  mean test accuracy over folds
    cp_stds  : (len(ks),)  sample standard deviation (ddof=1) over folds
    """
    cp_means = []
    cp_stds  = []

    for k in ks:
        # The neuron subset does not depend on the fold: select it once per k.
        X_k = X[:, worst_order[:k], :]            # (N, k, T)
        fold_acc = []

        for tr_idx, te_idx in splits:
            X_tr_full, y_tr = X_k[tr_idx], y[tr_idx]   # (N_tr, k, T)
            X_te_full, y_te = X_k[te_idx], y[te_idx]   # (N_te, k, T)

            # inner split for early stopping only (no HP tuning)
            X_tr_in, X_val_in, y_tr_in, y_val_in = train_test_split(
                X_tr_full, y_tr, test_size=inner_val_frac,
                stratify=y_tr, random_state=seed
            )

            # Standardize with statistics of the inner-training trials only,
            # so that neither validation nor test trials influence the scaler.
            scaler = fit_scaler_on_tensor(X_tr_in)
            Xtr_std  = transform_tensor_with_scaler(X_tr_in,   scaler)
            Xval_std = transform_tensor_with_scaler(X_val_in,  scaler)
            Xte_std  = transform_tensor_with_scaler(X_te_full, scaler)

            # Train on the inner-training trials; the validation trials must
            # not be part of the training set or early stopping is meaningless.
            model, _ = train_cp_model(Xtr_std, y_tr_in, Xval_std, y_val_in,
                                      cp_cfg, verbose=False)

            # test accuracy
            y_pred = predict_cp(model, Xte_std)
            fold_acc.append(accuracy_score(y_te, y_pred))

        fold_acc = np.array(fold_acc)
        mean_acc = fold_acc.mean()
        std_acc = fold_acc.std(ddof=1)
        cp_means.append(mean_acc)
        cp_stds.append(std_acc)

        if verbose:
            print(f"  k={k:5d} | mean acc={mean_acc*100:.2f}% "
                  f"(std={std_acc*100:.2f}%)")

    return np.array(cp_means), np.array(cp_stds)


# ===========================
# MAIN
# ===========================
if __name__ == "__main__":
    # Script entry point: load data, build the WORST-first neuron order, then
    # run the CP-logistic accuracy-vs-k curve for every hyperparameter
    # combination, plotting and saving each curve.
    from itertools import product  # local import: only the script section needs it

    set_seed(SEED)

    print("Loading tensor and labels from MAT/HDF5 files...")
    Xraw, yraw = _load_mat_h5py_first(tensor_path, labels_path)

    # Expect Xraw shape (Tr, T, N): trials × time × neurons
    if Xraw.ndim != 3:
        raise ValueError(f"Activity tensor must be 3D; got shape {Xraw.shape}")

    # Swap the last two axes: (Tr, T, N) -> (Tr, N, T) = (trials, neurons, time)
    X = np.transpose(Xraw, (0, 2, 1))

    # Labels: flatten and encode to 0..C-1
    y = np.array(yraw).ravel()
    le = LabelEncoder()
    y = le.fit_transform(y)

    X = X.astype(np.float32)
    y = y.astype(int)

    print(f"Loaded tensor: X.shape={X.shape} (trials, neurons, time)")
    print(f"Loaded labels: y.shape={y.shape}, classes={np.unique(y)}")

    # Class balance
    classes, counts = np.unique(y, return_counts=True)
    print("Class counts:", dict(zip(classes, counts)))
    majority = counts.max() / counts.sum()
    print(f"Majority baseline accuracy: {majority*100:.2f}%")

    # Prepare CV splits (outer) and optional time-averaged features (for debugging / info)
    splits = make_cv_splits(y, n_splits=N_SPLITS_OUTER, seed=SEED)
    Xavg = time_average(X)  # (trials, neurons)
    n_trials, n_neurons = Xavg.shape
    print(f"\nTime-averaged features: Xavg.shape={Xavg.shape}")

    # ===========================
    # Load neuron scores and build WORST-first order
    # ===========================
    print(f"\nLoading neuron scores from: {scores_path}")
    # Flatten so that a column/row vector saved as 2-D still sorts per neuron.
    scores = np.ravel(np.load(scores_path))   # shape (n_neurons,)
    if scores.shape[0] != n_neurons:
        raise ValueError(
            f"Scores length {scores.shape[0]} does not match n_neurons={n_neurons}"
        )

    # Assume higher score = better neuron (e.g. univariate accuracy).
    # For WORST-first, we sort ascending (lowest score first).
    worst_order = np.argsort(scores)

    np.save("worst_neuron_order_for_cp_HPscan_loaded.npy", worst_order)
    print("Neuron order: worst -> best according to loaded scores.")
    print("Saved 'worst_neuron_order_for_cp_HPscan_loaded.npy'.")

    # ===========================
    # Define k values (WORST-first), capped by K_MAX and with step K_STEP
    # ===========================
    k_max_eff = min(K_MAX, n_neurons)

    ks = np.arange(K_STEP, k_max_eff + 1, K_STEP, dtype=int)
    # If you ALSO want k=1 even when K_STEP>1, prepend 1 to ks here.

    # Fail with a clear message instead of an IndexError on ks[0] below.
    if ks.size == 0:
        raise ValueError(
            f"No k values to evaluate: K_STEP={K_STEP} is larger than the "
            f"usable number of neurons ({k_max_eff})."
        )

    print("ks =", ks)
    print(f"\nWill evaluate CP-logistic performance for k in {ks[0]}..{ks[-1]} neurons "
          f"({len(ks)} points, WORST-first; K_MAX={k_max_eff}, K_STEP={K_STEP})")

    # ===========================
    # Hyperparameter scan
    # ===========================
    all_results = {}  # (R, wd, lr, maxE, pat) -> dict with ks, means, stds

    # product() iterates in the same order as nested loops R > wd > lr > maxE > pat.
    for R, wd, lr, maxE, pat in product(R_grid, WD_grid, LR_grid, MAXE_grid, PAT_grid):
        print(f"\n=== CP config: R={R}, wd={wd}, lr={lr}, "
              f"max_epochs={maxE}, patience={pat} ===")
        cp_cfg = CPTrainConfig(
            rank=R,
            lr=lr,
            weight_decay=wd,
            batch_size=64,
            max_epochs=maxE,
            patience=pat,
            seed=SEED
        )

        cp_means, cp_stds = cp_cv_for_fixed_cfg_over_k(
            X, y, worst_order, splits, ks,
            cp_cfg, inner_val_frac=INNER_VAL_FRAC,
            seed=SEED, verbose=True
        )

        key = (R, wd, lr, maxE, pat)
        all_results[key] = {
            "ks": ks.copy(),
            "mean": cp_means,
            "std": cp_stds,
        }

        # tag for plots and filenames
        tag = f"R{R}_wd{wd:g}_lr{lr:g}_ep{maxE}_pat{pat}"

        # Plot curve for this config
        plt.figure(figsize=(7.0, 5.0))
        plt.errorbar(ks, cp_means * 100, yerr=cp_stds * 100,
                     fmt="-o", capsize=4,
                     label=f"CP ({tag})")
        plt.axhline(majority * 100, color="gray", linestyle="--",
                    label="Majority baseline")
        plt.xlabel("# neurons used (WORST-first by loaded scores)")
        # Report the fold count actually used instead of a hard-coded 5.
        plt.ylabel(f"{N_SPLITS_OUTER}-fold CV accuracy (%)")
        plt.title(f"CP-logistic acc vs #WORST neurons\n{tag}")
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.tight_layout()
        plt.show()

        # Save arrays per config
        np.save(f"cp_HPscan_{tag}_ks.npy", ks)
        np.save(f"cp_HPscan_{tag}_mean.npy", cp_means)
        np.save(f"cp_HPscan_{tag}_std.npy", cp_stds)

    print("\nHyperparameter scan finished.")
    print("Results stored in 'all_results' and saved as cp_HPscan_*.npy files.")
