# -*- coding: utf-8 -*-


# ===========================
# Paths (EDIT THESE)
# ===========================
# Activity tensor file; read by _load_mat_h5py_first (h5py first, scipy.io.loadmat as fallback).
tensor_path = r"E:\Lorenzo\activity_tensor_f.h5f"
# Trial labels file; read by the same loader as the tensor.
labels_path = r"C:\Users\Lorenzo\Documents\content\code\Fluorescence_Traces\data\trials_labels.mat"
# NumPy array loaded with np.load; its first dimension must equal the number of neurons.
scores_path = r"E:\Lorenzo\univar_best_scores.npy"   # per-neuron scores (e.g. univariate accuracies)

# ===========================
# Imports
# ===========================
import numpy as np
import h5py
import matplotlib.pyplot as plt

from dataclasses import dataclass

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression

import torch
import torch.nn as nn

# ===========================
# USER SETTINGS
# ===========================
# Seed used for Python/NumPy/torch RNGs and for both train_test_split calls.
SEED             = 0

# Number of neurons to use (WORST-first according to scores)
NUM_NEURONS      = 20        # <-- the 20 lowest-scoring neurons (capped at the number available)

# Splits
TEST_FRAC        = 0.2      # fraction of trials for final test
VAL_FRAC         = 0.2      # fraction of *training* for CP early stopping

# CP hyperparameters (forwarded to CPTrainConfig in the main section)
CP_RANK          = 5        # number of rank-one components R
CP_LR            = 1e-3     # Adam learning rate
CP_WEIGHT_DECAY  = 1e-3     # Adam weight_decay
CP_MAX_EPOCHS    = 2500     # upper bound on training epochs
CP_PATIENCE      = 30       # epochs without validation-loss improvement before stopping
CP_BATCH_SIZE    = 64       # mini-batch size

# Logistic Regression hyperparameters (forwarded to sklearn LogisticRegression)
LR_C             = 0.01     # passed as C
LR_MAX_ITER      = 2000     # passed as max_iter
LR_PENALTY       = "l2"     # passed as penalty
LR_SOLVER        = "lbfgs"  # passed as solver

# Torch device used for the CP model and its input tensors.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ===========================
# Utils: loading from .mat/h5
# ===========================
def _try_keys(h5file, candidates):
    """Return first matching dataset for any key in candidates."""
    for k in candidates:
        if k in h5file:
            return np.array(h5file[k])
    raise KeyError(
        f"None of the candidate keys {candidates} found. "
        f"Available keys: {list(h5file.keys())}"
    )


def _load_one_h5py_first(path, candidates, what):
    """Load one array from `path`, trying h5py first and scipy.io.loadmat second.

    `candidates` are the key names to try in priority order; `what` is a short
    label ("tensor" / "labels") used only in the key-not-found message.
    """
    # ---- Try HDF5 (MAT v7.3) ----
    try:
        with h5py.File(path, "r") as f:
            return _try_keys(f, candidates)
    except Exception as e_h5:
        # Deliberately broad: any h5py failure (not an HDF5 file, missing key,
        # ...) should lead to the loadmat attempt rather than abort.
        print("h5py load failed, trying scipy.io.loadmat...")
        # ---- Fallback: scipy.io.loadmat (older MAT) ----
        try:
            import scipy.io as spio

            contents = spio.loadmat(path)
            for k in candidates:
                if k in contents:
                    return contents[k]
            raise KeyError(f"Could not find {what} key in {list(contents.keys())}")
        except Exception as e_mat:
            # Report both failures and keep the loadmat traceback chained.
            raise RuntimeError(
                f"Failed to load MAT files.\n"
                f"HDF5 error: {e_h5}\nloadmat error: {e_mat}"
            ) from e_mat


def _load_mat_h5py_first(tensor_path, labels_path):
    """
    Robust loader: try h5py (MAT v7.3). If it fails, try scipy.io.loadmat.

    Each file is loaded independently, so an HDF5 tensor file can be combined
    with an older-format labels file (or vice versa). Previously a failure on
    either file sent *both* files to scipy.io.loadmat.

    Expected keys:
      - Tensor: 'activity_tensor' or 'tensor' or 'X'
      - Labels: 'labels' or 'y' or 'trials_labels'

    Parameters
    ----------
    tensor_path : str
        Path of the file holding the activity tensor.
    labels_path : str
        Path of the file holding the trial labels.

    Returns
    -------
    (X, y) : tuple
        Raw tensor and raw labels exactly as stored (no reshaping is done here).

    Raises
    ------
    RuntimeError
        If a file cannot be read by either backend, or none of the expected
        keys is present; the message contains both underlying errors.
    """
    X = _load_one_h5py_first(
        tensor_path, ["activity_tensor", "tensor", "X"], "tensor"
    )
    y = _load_one_h5py_first(
        labels_path, ["labels", "y", "trials_labels"], "labels"
    )
    return X, y

# ===========================
# Small helpers
# ===========================
def set_seed(seed=0):
    """Seed Python's `random`, NumPy and torch (CPU and all CUDA devices).

    Parameters
    ----------
    seed : int, default 0
        Value handed to every random number generator.

    Notes
    -----
    Torch seeding is best effort: any exception it raises is ignored.
    """
    import random as _py_random

    for seeder in (_py_random.seed, np.random.seed):
        seeder(seed)

    try:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    except Exception:
        # Best effort only: a torch seeding failure must not stop the script.
        pass


def time_average(X):
    """Collapse the last (time) axis of a 3D tensor by taking its mean.

    Parameters
    ----------
    X : array, shape (trials, neurons, time)

    Returns
    -------
    array, shape (trials, neurons)

    Raises
    ------
    ValueError
        If `X` is not 3-dimensional.
    """
    if X.ndim == 3:
        return X.mean(axis=2)
    raise ValueError(f"X must be 3D (trials, neurons, time), got shape {X.shape}")


# ===========================
# Standardization for tensors
# ===========================
def fit_scaler_on_tensor(X_train):
    """
    Fit a StandardScaler to a 3D tensor.

    The tensor (trials, neurons, time) is flattened to (trials, neurons*time),
    so each (neuron, time-bin) entry is treated as its own feature.

    Returns the fitted StandardScaler.
    """
    n_trials, n_neurons, n_time = X_train.shape
    flattened = X_train.reshape(n_trials, n_neurons * n_time)
    # StandardScaler.fit returns the fitted estimator itself.
    return StandardScaler().fit(flattened)


def transform_tensor_with_scaler(X, scaler):
    """
    Standardize a 3D tensor with an already-fitted scaler.

    X (trials, neurons, time) is flattened to (trials, neurons*time), passed
    through ``scaler.transform`` and reshaped back to its original 3D shape.
    """
    n_trials, n_neurons, n_time = X.shape
    flat_std = scaler.transform(X.reshape(n_trials, n_neurons * n_time))
    return flat_std.reshape(n_trials, n_neurons, n_time)

# ===========================
# CP-logistic model
# ===========================
class CPLogisticRegression(nn.Module):
    """
    Logistic model whose neuron x time weights are a rank-R CP factorization.

    Per trial n, with X_n of shape (I, J) = (neurons, time):
      - factor matrices A (I x R) and B (J x R)
      - component features z_{n,r} = a_r^T X_n b_r
      - class logits = Z @ W_class + bias, with W_class of shape (R, C)
    """

    def __init__(self, I, J, R, C):
        super().__init__()
        init_scale = 0.1
        # Draw order (A, B, W_class) is kept fixed so a given seed always
        # produces the same initialization.
        self.A = nn.Parameter(init_scale * torch.randn(I, R))
        self.B = nn.Parameter(init_scale * torch.randn(J, R))
        self.W_class = nn.Parameter(init_scale * torch.randn(R, C))
        self.bias = nn.Parameter(torch.zeros(C))

    def features(self, X):
        """
        Map a batch X of shape (N, I, J) to component features of shape (N, R).

        Step 1 contracts the time axis of X with B, giving (N, I, R).
        Step 2 weights that result by A and sums over the neuron axis.
        """
        time_projected = torch.tensordot(X, self.B, dims=([2], [0]))  # (N, I, R)
        return torch.sum(time_projected * self.A[None, :, :], dim=1)  # (N, R)

    def forward(self, X):
        """Return class logits of shape (N, C) for a batch X of shape (N, I, J)."""
        return self.features(X) @ self.W_class + self.bias


def _cp_gauge_fix(model: CPLogisticRegression):
    """
    Normalize columns of A and B, absorb norms into W_class for readability
    and numerical stability.
    """
    with torch.no_grad():
        An = torch.linalg.norm(model.A, dim=0).clamp_min(1e-8)
        Bn = torch.linalg.norm(model.B, dim=0).clamp_min(1e-8)
        scale = An * Bn
        model.A /= An
        model.B /= Bn
        model.W_class *= scale.unsqueeze(1)


@dataclass
class CPTrainConfig:
    """Hyperparameters consumed by train_cp_model."""

    # Number of CP components R passed to CPLogisticRegression.
    rank: int = 3
    # Learning rate given to torch.optim.Adam.
    lr: float = 1e-3
    # weight_decay given to torch.optim.Adam.
    weight_decay: float = 5e-4
    # Number of training trials per mini-batch.
    batch_size: int = 64
    # Upper bound on the number of training epochs.
    max_epochs: int = 300
    # Consecutive epochs without validation-loss improvement before stopping.
    patience: int = 30
    # Seed passed to set_seed at the start of training.
    seed: int = 0


def train_cp_model(X_train, y_train, X_val, y_val, cfg: CPTrainConfig, verbose=False):
    """
    Fit a CPLogisticRegression with mini-batch Adam and early stopping.

    Training stops after `cfg.patience` consecutive epochs without an
    improvement of the validation cross-entropy (or after `cfg.max_epochs`).
    The weights from the epoch with the lowest validation loss are restored
    before returning.

    Parameters
    ----------
    X_train, X_val : arrays, shape (N, I, J)
        Trial tensors; expected to be standardized by the caller.
    y_train, y_val : 1D integer arrays
        Class labels; the number of classes is taken from `y_train`.
    cfg : CPTrainConfig
        Hyperparameters.
    verbose : bool
        If True, print the validation loss at epoch 1 and every 10th epoch.

    Returns
    -------
    (model, val_acc) : (CPLogisticRegression, float)
        The trained model and its accuracy on the validation set.
    """
    set_seed(cfg.seed)

    n_rows = X_train.shape[1]
    n_cols = X_train.shape[2]
    n_classes = len(np.unique(y_train))

    model = CPLogisticRegression(n_rows, n_cols, cfg.rank, n_classes).to(DEVICE)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    def _on_device(array, dtype):
        return torch.tensor(array, dtype=dtype, device=DEVICE)

    train_x = _on_device(X_train, torch.float32)
    train_y = _on_device(y_train, torch.long)
    val_x = _on_device(X_val, torch.float32)
    val_y = _on_device(y_val, torch.long)

    lowest_val_loss = float("inf")
    best_weights = None
    stale_epochs = 0

    # Shuffled in place every epoch with the global NumPy RNG.
    order = np.arange(len(train_x))

    for epoch in range(1, cfg.max_epochs + 1):
        model.train()
        np.random.shuffle(order)
        for start in range(0, len(order), cfg.batch_size):
            batch = order[start:start + cfg.batch_size]
            batch_x, batch_y = train_x[batch], train_y[batch]
            optimizer.zero_grad()
            batch_loss = criterion(model(batch_x), batch_y)
            batch_loss.backward()
            optimizer.step()
            # Re-normalize the factors after every optimizer step.
            _cp_gauge_fix(model)

        # Validation loss for this epoch.
        model.eval()
        with torch.no_grad():
            vloss = criterion(model(val_x), val_y).item()

        if verbose and (epoch == 1 or epoch % 10 == 0):
            print(f"  epoch {epoch:03d}  val_loss={vloss:.4f}")

        if vloss + 1e-7 < lowest_val_loss:
            # New best: snapshot the weights on CPU and reset the patience counter.
            lowest_val_loss = vloss
            best_weights = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            stale_epochs = 0
            continue

        stale_epochs += 1
        if stale_epochs >= cfg.patience:
            break

    if best_weights is not None:
        model.load_state_dict(
            {name: tensor.to(DEVICE) for name, tensor in best_weights.items()}
        )

    model.eval()
    with torch.no_grad():
        val_pred = model(val_x).argmax(dim=1).cpu().numpy()
    val_acc = (val_pred == y_val).mean()

    return model, float(val_acc)


def predict_cp(model, X):
    """Return the arg-max class index per trial as a NumPy array.

    X: (N, I, J) tensor data, expected to be standardized like the training
    data. It is converted to a float32 tensor on DEVICE and evaluated without
    gradient tracking.
    """
    inputs = torch.tensor(X, dtype=torch.float32, device=DEVICE)
    with torch.no_grad():
        scores = model(inputs)
    return scores.argmax(dim=1).cpu().numpy()

# ===========================
# Weight-matrix computation & plotting
# ===========================
def compute_cp_weight_matrices(model):
    """
    Expand the CP factors of a trained CPLogisticRegression into one dense
    neuron×time weight matrix per class:

      W_c(n,t) = sum_r W_class[r,c] * A[n,r] * B[t,r]

    Returns
    -------
    W_classes : float array, shape (C, N, T)
        N = rows of A (neurons), T = rows of B (time bins), C = classes.
    """
    factor_A, factor_B, class_w = (
        param.detach().cpu().numpy()
        for param in (model.A, model.B, model.W_class)
    )  # (N, R), (T, R), (R, C)

    n_neurons, rank = factor_A.shape
    n_time = factor_B.shape[0]
    n_classes = class_w.shape[1]

    W_classes = np.zeros((n_classes, n_neurons, n_time), dtype=float)
    for r in range(rank):
        # Rank-one neuron×time pattern of component r, shared by all classes.
        pattern = np.outer(factor_A[:, r], factor_B[:, r])
        for c in range(n_classes):
            W_classes[c] += class_w[r, c] * pattern
    return W_classes


def plot_cp_weight_matrices(W0, W1, neuron_indices, title_prefix="CP weights"):
    """
    Show three heatmaps side by side: class-0 weights, class-1 weights and
    their difference (class 1 - class 0), on one symmetric color scale.

    W0, W1 : (N, T) matrices for class 0 and class 1
    neuron_indices: global neuron indices (length N), used as y tick labels

    Returns the difference matrix W1 - W0.
    """
    Wdiff = W1 - W0
    n_rows, _ = W0.shape

    # Symmetric limits shared by all panels; the epsilon keeps them non-zero.
    color_limit = np.max(np.abs(np.stack([W0, W1, Wdiff]))) + 1e-12

    fig = plt.figure(figsize=(14, 4))
    # Three heatmap columns plus a narrow column reserved for the colorbar.
    grid = fig.add_gridspec(1, 4, width_ratios=[1.0, 1.0, 1.0, 0.05])
    *heat_axes, cbar_ax = [fig.add_subplot(grid[0, col]) for col in range(4)]

    panels = (
        (W0, f"{title_prefix}: class 0"),
        (W1, f"{title_prefix}: class 1"),
        (Wdiff, f"{title_prefix}: class 1 - class 0"),
    )
    tick_positions = np.arange(n_rows)

    for axis, (matrix, heading) in zip(heat_axes, panels):
        image = axis.imshow(matrix, aspect="auto", origin="lower",
                            vmin=-color_limit, vmax=color_limit, cmap="bwr")
        axis.set_title(heading)
        axis.set_xlabel("Time bin")
        axis.set_yticks(tick_positions)
        axis.set_yticklabels(neuron_indices)

    heat_axes[0].set_ylabel("Neuron (global index)")

    # All panels share the same limits, so the last image serves the colorbar.
    colorbar = fig.colorbar(image, cax=cbar_ax)
    colorbar.set_label("Weight")

    plt.tight_layout()
    plt.show()

    return Wdiff



def plot_lr_weights_bar(w, neuron_indices):
    """
    Draw one bar per selected neuron showing its logistic-regression weight.

    w: shape (N,), interpreted as weights for class 1 vs class 0.
    neuron_indices: global neuron indices of length N, used as x tick labels.
    """
    positions = np.arange(len(w))

    fig, ax = plt.subplots(figsize=(4, 3))
    ax.bar(positions, w)
    # Zero reference line so the sign of each weight is easy to read.
    ax.axhline(0.0, color="k", linewidth=0.8)
    ax.set_xticks(positions)
    ax.set_xticklabels(neuron_indices)
    ax.set_xlabel("Neuron (global index)")
    ax.set_ylabel("LR weight (class 1 vs 0)")
    ax.set_title("Time-averaged LR weights")
    fig.tight_layout()
    plt.show()



def plot_cp_diff_vs_lr_heatmaps(W_diff, W_lr_mat, neuron_indices, title_prefix="CP vs LR"):
    """
    Side-by-side heatmaps:
      - left: CP W_diff (class1 - class0)
      - right: LR weights broadcast across time

    Both panels share one symmetric color scale, and the colorbar gets its own
    axis to avoid overlap.

    Parameters
    ----------
    W_diff : array, shape (N, T)
        CP weight difference matrix.
    W_lr_mat : array, same shape as W_diff
        LR weights repeated along the time axis.
    neuron_indices : sequence of length N
        Global neuron indices used as y tick labels.
    title_prefix : str
        Text placed before each panel title. A non-empty prefix that does not
        already end in a space or colon is followed by ": "; an empty prefix
        leaves the titles bare.
    """
    n_neurons, _ = W_diff.shape

    # Shared color scale; the epsilon keeps the limits non-zero.
    vmax = np.max(np.abs(np.stack([W_diff, W_lr_mat]))) + 1e-12

    # Without a separator the default prefix rendered as "CP vs LRCP W_diff ...".
    if title_prefix and not title_prefix.endswith((" ", ":")):
        lead = f"{title_prefix}: "
    else:
        lead = title_prefix

    fig = plt.figure(figsize=(11, 4))
    # 3 columns: CP, LR, colorbar
    gs = fig.add_gridspec(1, 3, width_ratios=[1.0, 1.0, 0.05])
    ax_cp = fig.add_subplot(gs[0, 0])
    ax_lr = fig.add_subplot(gs[0, 1])
    cax = fig.add_subplot(gs[0, 2])

    panels = (
        (ax_cp, W_diff, f"{lead}CP W_diff (class 1 - 0)"),
        (ax_lr, W_lr_mat, f"{lead}LR weights (time-averaged)"),
    )
    yticks = np.arange(n_neurons)
    for ax, mat, title in panels:
        im = ax.imshow(mat, aspect="auto", origin="lower",
                       vmin=-vmax, vmax=vmax, cmap="bwr")
        ax.set_title(title)
        ax.set_xlabel("Time bin")
        ax.set_yticks(yticks)
        ax.set_yticklabels(neuron_indices)

    ax_cp.set_ylabel("Neuron (global index)")

    # Both images use identical limits, so either one can drive the colorbar.
    cb = fig.colorbar(im, cax=cax)
    cb.set_label("Weight")

    plt.tight_layout()
    plt.show()


# ===========================
# MAIN
# ===========================
# Script entry point: load data, pick the lowest-scoring neurons, train the
# CP-logistic model and a time-averaged logistic regression on shared splits,
# then plot and compare their weights.
if __name__ == "__main__":
    set_seed(SEED)

    print("Loading tensor and labels from MAT/HDF5 files...")
    Xraw, yraw = _load_mat_h5py_first(tensor_path, labels_path)

    # Expect Xraw shape (Tr, T, N): trials × time × neurons
    # NOTE(review): this axis order is assumed, not checked — confirm for the loader backend used.
    if Xraw.ndim != 3:
        raise ValueError(f"Activity tensor must be 3D; got shape {Xraw.shape}")

    # Convert (Tr, T, N) -> (N, T, Tr) -> (Tr, N, T)
    # (net effect of the two steps: the last two axes are swapped)
    X = np.transpose(Xraw, (2, 1, 0))  # (N, T, Tr)
    X = np.moveaxis(X, -1, 0)          # (Tr, N, T) = (trials, neurons, time)

    # Labels: flatten and encode to 0..C-1
    y = np.array(yraw).ravel()
    le = LabelEncoder()
    y = le.fit_transform(y)

    X = X.astype(np.float32)
    y = y.astype(int)

    print(f"Loaded tensor: X.shape={X.shape} (trials, neurons, time)")
    print(f"Loaded labels: y.shape={y.shape}, classes={np.unique(y)}")

    # Check binary
    n_classes = len(np.unique(y))
    if n_classes != 2:
        raise ValueError(f"This script assumes 2 classes; found {n_classes}.")

    # Class balance: the majority-class fraction is the chance-level reference
    classes, counts = np.unique(y, return_counts=True)
    print("Class counts:", dict(zip(classes, counts)))
    majority = counts.max() / counts.sum()
    print(f"Majority baseline accuracy: {majority*100:.2f}%")

    # ===========================
    # Load neuron scores and build WORST-first order
    # ===========================
    print(f"\nLoading neuron scores from: {scores_path}")
    scores = np.load(scores_path)   # shape (n_neurons,)
    n_trials, n_neurons, T = X.shape
    if scores.shape[0] != n_neurons:
        raise ValueError(
            f"Scores length {scores.shape[0]} does not match n_neurons={n_neurons}"
        )

    # Assume higher score = better neuron (e.g. univariate accuracy).
    # For WORST-first, we sort ascending (lowest score first).
    worst_order = np.argsort(scores)

    print("Neuron order: worst -> best according to loaded scores.")
    print("Global indices of worst neurons:", worst_order[:10])

    # ===========================
    # Select the WORST neurons
    # ===========================
    # Cap the request at the number of neurons actually available.
    k_use = min(NUM_NEURONS, n_neurons)
    neuron_idx = worst_order[:k_use]
    print(f"\nUsing NUM_NEURONS={NUM_NEURONS}, effective k_use={k_use}.")
    print("Selected neuron indices (worst):", neuron_idx)

    # Subset tensor and time-averaged features
    X_sub = X[:, neuron_idx, :]        # (trials, k_use, T)
    Xavg_sub = time_average(X_sub)     # (trials, k_use)

    # ===========================
    # Create shared train/val/test splits
    # ===========================
    # Splits are made on trial indices so both models see the same trials.
    idx_all = np.arange(n_trials)
    # First: train_full vs test
    idx_train_full, idx_test = train_test_split(
        idx_all,
        test_size=TEST_FRAC,
        stratify=y,
        random_state=SEED,
    )
    # Second: train vs val (within train_full)
    idx_train, idx_val = train_test_split(
        idx_train_full,
        test_size=VAL_FRAC,
        stratify=y[idx_train_full],
        random_state=SEED,
    )

    print(f"\nTrials: total={n_trials}, train={len(idx_train)}, "
          f"val={len(idx_val)}, test={len(idx_test)}")

    # Build CP tensors for those splits
    X_train_full = X_sub[idx_train_full]   # for scaler
    X_train      = X_sub[idx_train]
    X_val        = X_sub[idx_val]
    X_test       = X_sub[idx_test]

    y_train      = y[idx_train]
    y_val        = y[idx_val]
    y_test       = y[idx_test]

    # Standardize CP input (fit on train_full, apply everywhere)
    # NOTE(review): train_full contains the validation trials, so the scaler
    # statistics are influenced by data later used for early stopping; the
    # test trials are not involved.
    scaler_cp = fit_scaler_on_tensor(X_train_full)
    X_train_std = transform_tensor_with_scaler(X_train, scaler_cp)
    X_val_std   = transform_tensor_with_scaler(X_val,   scaler_cp)
    X_test_std  = transform_tensor_with_scaler(X_test,  scaler_cp)

    # ===========================
    # Train CP model
    # ===========================
    cp_cfg = CPTrainConfig(
        rank=CP_RANK,
        lr=CP_LR,
        weight_decay=CP_WEIGHT_DECAY,
        batch_size=CP_BATCH_SIZE,
        max_epochs=CP_MAX_EPOCHS,
        patience=CP_PATIENCE,
        seed=SEED,
    )
    print("\nTraining CP-logistic model...")
    cp_model, val_acc = train_cp_model(
        X_train_std, y_train, X_val_std, y_val, cp_cfg, verbose=False
    )
    print(f"CP: inner val acc = {val_acc*100:.2f}%")

    y_pred_cp = predict_cp(cp_model, X_test_std)
    acc_cp = accuracy_score(y_test, y_pred_cp)
    print(f"CP: test acc      = {acc_cp*100:.2f}%")

    # ===========================
    # Compute & plot CP weight matrices
    # ===========================
    W_classes = compute_cp_weight_matrices(cp_model)   # (2, k_use, T)
    W0 = W_classes[0]
    W1 = W_classes[1]

    # plot per-class and diff; also grab W_diff
    W_diff = plot_cp_weight_matrices(W0, W1, neuron_idx,
                                     title_prefix="CP neuron×time weights")

    # ===========================
    # Train LR on time-averaged features (same splits)
    # ===========================
    # NOTE(review): LR is fit on train_full (train + val) because it needs no
    # early stopping, whereas the CP model above is fit on the train part only.
    print("\nTraining time-averaged Logistic Regression...")

    Xavg_train_full = Xavg_sub[idx_train_full]        # for scaler
    Xavg_test       = Xavg_sub[idx_test]
    y_train_full_lr = y[idx_train_full]
    y_test_lr       = y[idx_test]

    scaler_lr = StandardScaler()
    scaler_lr.fit(Xavg_train_full)
    Xavg_train_full_std = scaler_lr.transform(Xavg_train_full)
    Xavg_test_std       = scaler_lr.transform(Xavg_test)

    lr_clf = LogisticRegression(
        C=LR_C,
        penalty=LR_PENALTY,
        solver=LR_SOLVER,
        max_iter=LR_MAX_ITER,
    )
    lr_clf.fit(Xavg_train_full_std, y_train_full_lr)
    y_pred_lr = lr_clf.predict(Xavg_test_std)
    acc_lr = accuracy_score(y_test_lr, y_pred_lr)
    print(f"LR: test acc      = {acc_lr*100:.2f}%")

    # LR coefficients: for a binary problem coef_ has shape (1, k_use),
    # so ravel() yields one weight per selected neuron.
    w_lr = lr_clf.coef_.ravel()  # class 1 vs class 0

    print("\nLR weights (class 1 vs 0) for selected neurons:")
    for idx, w in zip(neuron_idx, w_lr):
        print(f"  neuron {idx}: {w:.4f}")

    # bar plot (optional)
    plot_lr_weights_bar(w_lr, neuron_idx)

    # ===========================
    # Build LR "weight matrix" comparable to W_diff
    # ===========================
    # Broadcast the per-neuron LR weights across time bins
    W_lr_mat = np.tile(w_lr[:, np.newaxis], (1, T))  # shape (k_use, T)

    # Side-by-side heatmap: CP W_diff vs LR
    # An empty prefix is passed so the panel titles carry no leading text.
    plot_cp_diff_vs_lr_heatmaps(W_diff, W_lr_mat, neuron_idx,
                                title_prefix="")

    print("\nDone.")
