%% ===== Rank checks for CP (neurons x time x trials) pairwise medoid =====
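% Inputs in workspace:
%   Xf_t            : tensor that is actually decomposed (assigned to Xfull below);
%                     presumably derived from activity_tensor -- confirm upstream
%   activity_tensor : N x T x K  (double); only used for the 3-D sanity check
%   labels          : K x 1       (kept for compatibility; not used below)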
%clearvars -except activity_tensor labels

% --- sanity
% activity_tensor is only inspected here; the decomposition itself runs on Xf_t,
% so validate that variable too and fail fast (before any pool is started)
% instead of dying later with "Undefined variable".
assert(ndims(activity_tensor)==3, 'Expected 3D array (neurons x time x trials).');
assert(exist('Xf_t','var')==1, 'Expected workspace variable Xf_t (tensor to decompose).');
assert(ndims(Xf_t)==3, 'Expected Xf_t to be a 3-way tensor (neurons x time x trials).');

% --- reproducibility
BASE_SEED = 12345;
rng(BASE_SEED,'twister');   % global, deterministic

% --- config
R_range       = 2:10;        % candidate CP ranks to evaluate
n_restarts    = 50;          % random restarts per rank (raise for smoother stability stats)
solver        = 'cp_nmu';    % 'cp_nmu' (nonneg) or 'cp_als' (unconstrained)
opts          = struct('maxiters',1000,'tol',1e-7,'printitn',100);

% parallel (optional)
% gcp/parpool need the Parallel Computing Toolbox. The whole probe is wrapped
% in the try so that a missing toolbox or a pool start-up failure is reported
% as a warning instead of aborting the script; the parfor below is left to
% run however MATLAB can without a pool.
try
    if isempty(gcp('nocreate')), parpool; end
catch poolErr
    warning('rank_checks:noParpool', ...
        'Could not start a parallel pool (%s); continuing without one.', poolErr.message);
end

% ----- containers (one entry per rank in R_range)
train_fit_best   = zeros(size(R_range));
train_fit_mean   = zeros(size(R_range));
train_fit_std    = zeros(size(R_range));

% stability (FMS to medoid) stats per rank
fms_mean         = zeros(size(R_range));
fms_median       = zeros(size(R_range));
fms_std          = zeros(size(R_range));
fms_min          = zeros(size(R_range));
fms_p25          = zeros(size(R_range));
fms_p75          = zeros(size(R_range));
fms_max          = zeros(size(R_range));

% collinearity / duplicate / residual diagnostics on medoid
collinear_maxA   = zeros(size(R_range));
collinear_medA   = zeros(size(R_range));
collinear_maxB   = zeros(size(R_range));
collinear_medB   = zeros(size(R_range));
collinear_maxC   = zeros(size(R_range));
collinear_medC   = zeros(size(R_range));
dup_score        = zeros(size(R_range));
res_frac1_mode1  = zeros(size(R_range));
res_frac1_mode2  = zeros(size(R_range));
res_frac1_mode3  = zeros(size(R_range));
res_time_autoc   = zeros(size(R_range));

% keep representatives & top-10 per rank
medoid_per_R     = cell(size(R_range));          % ktensor medoid
best10_per_R     = cell(size(R_range));          % 1x<=10 cell of ktensors
best10_fits      = cell(size(R_range));          % corresponding fits
FMS_matrix       = cell(size(R_range));          % pairwise FMS (n_restarts x n_restarts)
FMS_to_medoid    = cell(size(R_range));          % vector (n_restarts x 1), includes 1 for self

fprintf('\n==== Running rank diagnostics (no CV; pairwise medoid) ====\n');
% NOTE: the pre-built workspace tensor Xf_t is decomposed, not activity_tensor.
Xfull = Xf_t;%tensor(activity_tensor);

% Main sweep: for every candidate rank, fit n_restarts CP models on the full
% tensor, summarise their training fit, pick the medoid restart from the
% pairwise FMS matrix, and compute collinearity / duplicate / residual
% diagnostics on that medoid. Results go into the per-rank containers.
for ir = 1:numel(R_range)
    R = R_range(ir);
    fprintf('\n=== R = %d ===\n', R);

    % ---------- MULTI-RESTART on FULL data ----------
    Pruns = cell(1,n_restarts);      % fitted model of each restart
    fits  = -inf(1,n_restarts);      % training fit of each restart

    % derive deterministic per-restart seeds for this rank:
    % BASE_SEED + 100000*R + restart index
    restart_seeds = BASE_SEED + uint32(100000*R) + uint32(1:n_restarts);

    parfor r = 1:n_restarts
        % reseed per restart (per iteration, not per worker) from restart_seeds
        rng(restart_seeds(r), 'twister');
        P = fit_cp_with_solver_simple(Xfull, R, solver, opts); %#ok<PFBNS>
        Pruns{r} = P;
        fits(r)  = tt_fit_simple(Xfull, P);   % 1 - relative Frobenius error
    end

    % store train fit summary
    train_fit_best(ir) = max(fits);
    train_fit_mean(ir) = mean(fits);
    train_fit_std(ir)  = std(fits);

    % keep top-10 by training fit
    [~, ord] = sort(fits, 'descend');
    topk = ord(1:min(10, n_restarts));
    best10_per_R{ir} = Pruns(topk);
    best10_fits{ir}  = fits(topk);

    % ---------- EXHAUSTIVE PAIRWISE FMS & TRUE MEDOID ----------
    F = fms_pairwise_matrix(Pruns);                % n_restarts x n_restarts, F_ii = 1
    FMS_matrix{ir} = F;

    % Row mean of F. It includes the diagonal self-similarity of 1, but that
    % adds the same constant to every row, so the argmax (medoid) is the
    % restart with the highest average FMS to the other restarts.
    avgF = mean(F, 2);
    [~, med_idx] = max(avgF);
    P_med = Pruns{med_idx};
    medoid_per_R{ir} = P_med;

    f_med = F(:, med_idx);                         % includes 1 for self
    FMS_to_medoid{ir} = f_med;

    % stability stats (exclude self when summarizing)
    % NOTE(review): with n_restarts == 1 f_excl is empty, so these summaries
    % are undefined for a single restart.
    f_excl = f_med; f_excl(med_idx) = [];
    fms_mean(ir)   = mean(f_excl);
    fms_median(ir) = median(f_excl);
    fms_std(ir)    = std(f_excl);
    fms_min(ir)    = min(f_excl);
    fms_p25(ir)    = prctile(f_excl,25);
    fms_p75(ir)    = prctile(f_excl,75);
    fms_max(ir)    = max(f_excl);

    % ---------- DIAGNOSTICS on MEDOID ----------
    % Factor matrices of the medoid, one per mode (labelled neurons / time /
    % trials in the plots below).
    [A,B,C] = deal( double(P_med.U{1}), double(P_med.U{2}), double(P_med.U{3}) );
    [collinear_maxA(ir), collinear_medA(ir)] = intra_mode_collinearity_simple(A);
    [collinear_maxB(ir), collinear_medB(ir)] = intra_mode_collinearity_simple(B);
    [collinear_maxC(ir), collinear_medC(ir)] = intra_mode_collinearity_simple(C);
    dup_score(ir) = duplicate_component_score_simple(P_med);  % larger => more duplicates

    % residual analysis: dense residual of the medoid reconstruction
    Res = double(Xfull) - double(full(P_med));
    [f1a,f1b,f1c] = residual_topsv_fraction_simple(Res);   % top-SV energy share per unfolding
    res_frac1_mode1(ir) = f1a;
    res_frac1_mode2(ir) = f1b;
    res_frac1_mode3(ir) = f1c;
    res_time_autoc(ir)  = residual_time_autocorr_simple(Res);  % mean lag-1 autocorr along dim 2
end

%% ---------- PLOTS ----------
% Five figures, all indexed by rank on the x-axis (R_range) and drawn from
% the per-rank containers filled by the sweep above.

% 1) Fit vs R with variance (restarts)
%    Error bars: mean +/- std of training fit over restarts; dashed: best restart.
figure('Name','Training fit across restarts','Color','w');
hold on;
errorbar(R_range, train_fit_mean, train_fit_std, '-o','LineWidth',1.5);
plot(R_range, train_fit_best, '--s','LineWidth',1.2);
grid on; xlabel('Rank R'); ylabel('Fit (1 - ||X - \hat{X}||_F / ||X||_F)');
legend('mean \pm std (restarts)', 'best restart', 'Location','southeast');
title('Training fit vs rank (across restarts)');

% 2) Stability across restarts (FMS to medoid): mean ± std + median
%    Statistics exclude the medoid's self-similarity (see f_excl in the sweep).
figure('Name','Stability across restarts (FMS to medoid)','Color','w');
hold on;
errorbar(R_range, fms_mean, fms_std, '-o','LineWidth',1.5);
plot(R_range, fms_median, '--s','LineWidth',1.2);
grid on; xlabel('Rank R'); ylabel('FMS to medoid');
legend('mean \pm std', 'median', 'Location','southeast');
title('Pairwise-FMS medoid stability');

% 3) Collinearity diagnostics
%    Top row: max off-diagonal |cos| per mode; bottom row: median off-diagonal |cos|.
figure('Name','Collinearity diagnostics','Color','w');
tiledlayout(2,3,'Padding','compact','TileSpacing','compact');
nexttile; plot(R_range,collinear_maxA,'-o'); grid on; title('A (neurons): max cos');
nexttile; plot(R_range,collinear_maxB,'-o'); grid on; title('B (time): max cos');
nexttile; plot(R_range,collinear_maxC,'-o'); grid on; title('C (trials): max cos');
nexttile; plot(R_range,collinear_medA,'-o'); grid on; title('A (neurons): median cos');
nexttile; plot(R_range,collinear_medB,'-o'); grid on; title('B (time): median cos');
nexttile; plot(R_range,collinear_medC,'-o'); grid on; title('C (trials): median cos');

% 4) Duplicate-components score
%    Largest product-of-cosines similarity between two distinct medoid components.
figure('Name','Duplicate-components score','Color','w');
plot(R_range, dup_score,'-o','LineWidth',1.5); grid on;
xlabel('Rank R'); ylabel('Duplicate score (tri-mode off-diag sim)');
title('Higher suggests over-factoring');

% 5) Residual structure
%    Three panels: share of squared singular values carried by the top singular
%    value of each residual unfolding; fourth: mean lag-1 autocorrelation in time.
figure('Name','Residual structure','Color','w');
tiledlayout(2,2,'Padding','compact','TileSpacing','compact');
nexttile; plot(R_range,res_frac1_mode1,'-o'); grid on; title('Residual: top SV frac (mode-1)');
nexttile; plot(R_range,res_frac1_mode2,'-o'); grid on; title('Residual: top SV frac (mode-2)');
nexttile; plot(R_range,res_frac1_mode3,'-o'); grid on; title('Residual: top SV frac (mode-3)');
nexttile; plot(R_range,res_time_autoc,'-o'); grid on; title('Residual: avg lag-1 autocorr (time)');

%% ---------- Console summary ----------
% Print the rank with the highest best-restart training fit, then one line per
% rank with fit statistics and the median / interquartile range of FMS-to-medoid.
% Training fit is computed on the full data (no cross-validation).
[~, ix_best_train] = max(train_fit_best);
fprintf('\n==== Summary (no CV) ====\n');
fprintf('Best training fit: %.4f at R=%d\n', train_fit_best(ix_best_train), R_range(ix_best_train));
for ir = 1:numel(R_range)
    R = R_range(ir);
    fprintf('R=%d | fit: mean=%.4f std=%.4f best=%.4f | FMS-to-medoid: median=%.3f IQR=[%.3f, %.3f]\n', ...
        R, train_fit_mean(ir), train_fit_std(ir), train_fit_best(ir), fms_median(ir), fms_p25(ir), fms_p75(ir));
end

%% ======================= Local functions (must be at end) =======================
function P = fit_cp_with_solver_simple(Xtrain, R, solver, opts)
% Fit a rank-R CP model to Xtrain with the solver selected by name.
%   solver : 'cp_nmu' or 'cp_als' (matched case-insensitively)
%   opts   : struct; passed as-is to cp_nmu, while cp_als receives its
%            tol / maxiters / printitn fields as name-value pairs
% Raises an error for any other solver name.
    solverKey = lower(solver);
    if strcmp(solverKey, 'cp_nmu')
        P = cp_nmu(Xtrain, R, opts);
    elseif strcmp(solverKey, 'cp_als')
        alsArgs = {'tol', opts.tol, 'maxiters', opts.maxiters, 'printitn', opts.printitn};
        P = cp_als(Xtrain, R, alsArgs{:});
    else
        error('Unknown solver: %s', solver);
    end
end

function fit = tt_fit_simple(X, P)
% Relative fit of model P to data X: 1 - ||X - full(P)||_F / ||X||_F.
%   X   : data (converted with double())
%   P   : model; full(P) must be subtractable from X
%   fit : scalar; 1 means exact reconstruction
% The Frobenius norm is taken as the 2-norm of the vectorised array because
% norm(A,'fro') rejects N-D input on some MATLAB releases; the value is the same.
    resid = double(X - full(P));
    dataArr = double(X);
    fit = 1 - norm(resid(:)) / norm(dataArr(:));
end

function Y = norm_cols_simple(A)
% Scale every column of A by its Euclidean norm.
% Norms are floored at eps, so an all-zero column stays zero instead of
% turning into NaN.
    colNorm = vecnorm(A);
    divisor = max(colNorm, eps);
    Y = bsxfun(@rdivide, A, divisor);
end

function F = fms_pairwise_matrix(Pruns)
% Symmetric matrix of Factor Match Scores between every pair of models in
% the cell array Pruns. The diagonal is 1; each unordered pair is scored
% once with fms_cp_simple and mirrored across the diagonal.
    nRuns = numel(Pruns);
    F = eye(nRuns);
    for col = 2:nRuns
        for row = 1:col-1
            score = fms_cp_simple(Pruns{row}, Pruns{col});
            F(row, col) = score;
            F(col, row) = score;
        end
    end
end

function s = fms_cp_simple(P1, P2)
% Factor Match Score via multiplicative per-mode cosines + optimal assignment
%   P1, P2 : models exposing .lambda (length = number of components) and
%            .U (cell of factor matrices, one per mode); both must have the
%            same number of components
%   s      : mean over components of the matched similarity, where the
%            similarity of a component pair is the product over modes of the
%            absolute cosine between their unit-normalised factor columns
% The assignment uses matchpairs when available, then munkres, then a greedy
% row-by-row fallback.
    R = length(P1.lambda);
    if length(P2.lambda) ~= R
        error('fms_cp_simple:rankMismatch', ...
            'Models must have the same number of components (%d vs %d).', R, length(P2.lambda));
    end
    U1 = cellfun(@(U) norm_cols_simple(double(U)), P1.U, 'uni', 0);
    U2 = cellfun(@(U) norm_cols_simple(double(U)), P2.U, 'uni', 0);
    S = ones(R);
    for m = 1:numel(U1)
        S = S .* abs(U1{m}' * U2{m});   % similarities in [0,1]
    end
    % Solve assignment
    if exist('matchpairs','file')==2
        % matchpairs minimises cost, so negate S. The unmatched cost must be
        % larger than any attainable gain: with a cost of 0, pairs whose
        % similarity is exactly 0 can be left unmatched and fewer than R
        % pairs come back.
        costUnmatched = 1 + max([0; abs(S(:))]);
        pairs = matchpairs(-S, costUnmatched);
        vals = S(sub2ind([R R], pairs(:,1), pairs(:,2)));
        % Divide by R (not numel(vals)) so the score is always a per-component mean.
        s = sum(vals) / R;
    elseif exist('munkres','file')==2
        idx = munkres(-S); s = mean(S(sub2ind([R R], 1:R, idx)));
    else
        % greedy fallback: each row takes its best still-unused column
        Sg = S; used = false(1,R); sc = zeros(1,R);
        for r=1:R
            [mx, j] = max(Sg(r,:));
            while used(j), Sg(r,j) = -Inf; [mx, j] = max(Sg(r,:)); end
            used(j) = true; sc(r) = mx;
        end
        s = mean(sc);
    end
end

function [max_off, med_off] = intra_mode_collinearity_simple(U)
% Collinearity of the columns of one factor matrix U.
%   max_off : largest absolute cosine between two distinct columns
%   med_off : median absolute cosine over all ordered pairs of distinct columns
% Both are 0 when there are no off-diagonal entries (e.g. a single column).
    Un = norm_cols_simple(U);
    G = abs(Un.' * Un);                      % R x R absolute cosines
    G(logical(eye(size(G,1)))) = NaN;        % blank out self-similarities
    offdiag = G(~isnan(G));
    if isempty(offdiag)
        max_off = 0;
        med_off = 0;
    else
        max_off = max(offdiag);
        med_off = median(offdiag);
    end
end

function score = duplicate_component_score_simple(P)
% Duplicate-component score of a CP model P (.lambda, .U).
% For components i ~= j the similarity is the product over all modes of the
% absolute cosine between their unit-normalised factor columns; the score is
% the largest such similarity, or 0 when there are no off-diagonal entries.
    nComp = length(P.lambda);
    sim = ones(nComp);
    for m = 1:numel(P.U)
        Um = norm_cols_simple(double(P.U{m}));
        sim = sim .* abs(Um' * Um);
    end
    sim(logical(eye(size(sim,1)))) = NaN;    % ignore each component vs itself
    offdiag = sim(~isnan(sim));
    if isempty(offdiag)
        score = 0;
    else
        score = max(offdiag);
    end
end

function [frac1_mode1, frac1_mode2, frac1_mode3] = residual_topsv_fraction_simple(Res)
% For each of the three mode unfoldings of the 3-D array Res, return the
% share of total squared singular values carried by the largest singular
% value. Values near 1 mean the residual unfolding is close to rank one.
    dims   = [size(Res,1), size(Res,2), size(Res,3)];
    orders = [1 2 3; 2 1 3; 3 1 2];          % row m brings mode m to the front
    topFrac = zeros(1,3);
    for m = 1:3
        ord = orders(m,:);
        unfolded = reshape(permute(Res, ord), dims(ord(1)), prod(dims(ord(2:3))));
        energy = svd(unfolded, 'econ').^2;   % squared singular values
        topFrac(m) = energy(1) / sum(energy);
    end
    frac1_mode1 = topFrac(1);
    frac1_mode2 = topFrac(2);
    frac1_mode3 = topFrac(3);
end

function r = residual_time_autocorr_simple(Res)
% Mean lag-1 autocorrelation along time of residual across (neurons, trials)
%   Res : 3-D array; the autocorrelation is taken along dimension 2 for
%         every (dim-1, dim-3) fibre
%   r   : mean of the per-fibre values, ignoring NaN
% A fibre contributes 0 when it has fewer than 3 samples or zero variance
% (any constant fibre, not only an all-zero one). The previous guard
% "(den>0) * (num/den)" evaluated to 0*NaN = NaN for constant non-zero fibres,
% so they were dropped from the mean instead of counted as 0.
% Fibres containing NaN still yield NaN and are omitted.
    N = size(Res,1); T = size(Res,2); K = size(Res,3);
    ac = zeros(N,K);
    if T >= 3
        for n = 1:N
            for k = 1:K
                x = Res(n,:,k); x = x(:);
                xm = x - mean(x);
                den = sum( xm.^2 );
                if den > 0
                    ac(n,k) = sum( xm(1:end-1).*xm(2:end) ) / den;
                elseif isnan(den)
                    ac(n,k) = NaN;          % non-finite data: excluded by 'omitnan'
                end                          % den == 0: leave the preset 0
            end
        end
    end
    r = mean(ac(:),'omitnan');
end
