%% ===== Plot the x-th best model at a chosen rank =====
% Inputs expected in workspace:
%   R_range        : vector of ranks you evaluated (e.g., 2:10)
%   best10_per_R   : cell array; each {ir} is 1x<=10 cell of ktensors (top by train fit)
%   best10_fits    : (optional) cell array; fits corresponding to best10_per_R
%   labels         : (optional) Kx1 trial-type labels for coloring C

% ---- USER PARAMS ----
% Which stored model to display.
R_target  = 5;            % rank to inspect; must be one of the values in R_range
model_idx = 1;            % position in that rank's stored list (x-th best, 1..<=10)

% How to display it.
sort_by_lambda = true;    % true: columns of the figure follow descending lambda
topkA          = 50;      % number of largest-|A| neurons drawn per component

% ---- locate the rank index and grab the model ----
% ir is the position of R_target inside R_range; the same position is used to
% index best10_per_R (and best10_fits, when present).
ir = find(R_range == R_target, 1);
assert(~isempty(ir), 'R_target not found in R_range.');
models = best10_per_R{ir};
assert(~isempty(models), 'No stored models for this rank.');
assert(model_idx >= 1 && model_idx <= numel(models), 'model_idx out of bounds for this R.');

% P must expose P.U (cell of three factor matrices) and P.lambda (component weights).
P = models{model_idx};
A = double(P.U{1}); B = double(P.U{2}); C = double(P.U{3});
lam = double(P.lambda(:));
R  = numel(lam);
assert(size(A,2) == R && size(B,2) == R && size(C,2) == R, ...
       'Factor matrices and lambda disagree on the number of components.');
M=P;   % alias of the selected model; unused below. NOTE(review): presumably read by other scripts - confirm before removing.

% Normalize the columns of A and B to unit 2-norm and move the removed scale into
% lam, so that lam(r) * A(:,r) o B(:,r) o C(:,r) is unchanged by the normalization.
% Without this, a model whose factors are not already unit-norm would lose its
% component strength, and sorting by lam would not reflect it.
% The dimension is given explicitly: vecnorm(X) on a one-row matrix would return
% a single scalar instead of one norm per column.
% max(eps, .) guards the division for all-zero columns (they stay zero).
normc   = @(X) X ./ max(eps, vecnorm(X, 2, 1));
scaleAB = max(eps, vecnorm(A, 2, 1)) .* max(eps, vecnorm(B, 2, 1));   % 1 x R
A = normc(A); B = normc(B);
lam = lam .* scaleAB(:);
C = C .* lam.';     % scale C(:,r) by lambda(r) so the trial weights carry the strength

% Column order used when plotting: ord(jj) is the component drawn in column jj.
if sort_by_lambda
    % strongest component (largest lambda) first
    [~, ord] = sort(lam, 'descend');
else
    % keep the order in which the components are stored in the model
    ord = 1:R;
end

% ---- labels (optional) for coloring C
% Labels are used only when the workspace variable exists and is non-empty.
have_labels = false;
if exist('labels','var')
    have_labels = ~isempty(labels);
end

if have_labels
    labs    = double(reshape(labels, [], 1));   % labels as a numeric column
    uLabs   = unique(labs);                     % distinct label values, sorted
    cmap    = lines(length(uLabs));             % one color row per distinct label
    % lab2col maps a label value to its row in cmap
    lab2col = containers.Map(num2cell(uLabs), num2cell(1:length(uLabs)));
end

% ---- figure layout: 3 rows (A,B,C) × R columns
% White figure whose window name records the rank and the chosen model.
fig = figure('Color','w');
fig.Name = sprintf('Rank R=%d | %d-th best model | components', R_target, model_idx);

% Tiles are numbered row by row: 1..R for A, R+1..2R for B, 2R+1..3R for C.
tl = tiledlayout(fig, 3, R);
tl.Padding     = 'compact';
tl.TileSpacing = 'compact';

% Build the overall title; include the stored fit when one exists for this model.
% best10_fits is optional, so check that it exists and has an entry for (ir, model_idx).
% The title names the model as the model_idx-th best in both branches, matching the
% figure name (the previous "Medoid" wording did not describe what is plotted here).
has_fit = exist('best10_fits','var') && numel(best10_fits) >= ir ...
          && numel(best10_fits{ir}) >= model_idx;
if has_fit
    fit_m = best10_fits{ir}(model_idx);
    suptitle_str = sprintf('Rank R=%d — %d-th best (fit=%.4f) — A(top-%d) / B(time) / C(trials)', ...
                           R_target, model_idx, fit_m, topkA);
else
    suptitle_str = sprintf('Rank R=%d — %d-th best — A(top-%d) / B(time) / C(trials)', ...
                           R_target, model_idx, topkA);
end

% ---- plot each component
% One column per component, in the order given by ord:
%   row 1 = top-|A| neuron loadings, row 2 = B over time, row 3 = C over trials.

% Fail early with a clear message instead of an obscure xlim error inside the loop.
assert(topkA >= 1, 'topkA must be at least 1.');

nTrials = size(C, 1);
if have_labels
    % Resolve every trial's color once, outside the loop, so each C panel needs a
    % single scatter call rather than one graphics object per trial.
    % Only the first nTrials labels are used; a shorter label vector is an error.
    assert(numel(labs) >= nTrials, 'labels has fewer entries than there are trials (rows of C).');
    [~, labCol] = ismember(labs(1:nTrials), uLabs);   % row of cmap for each trial
    trialColors = cmap(labCol, :);                    % nTrials x 3 RGB
end

for jj = 1:R
    r = ord(jj);   % component shown in column jj

    % Row 1: A (top-k neurons), ranked by absolute loading
    nexttile(tl, jj);
    ar = A(:, r);
    [~, idx] = sort(abs(ar), 'descend');
    sel = idx(1:min(topkA, numel(ar)));   % fewer than topkA if A has fewer rows
    stem(1:numel(sel), ar(sel), 'filled');
    grid on; title(sprintf('A_r=%d (top-%d)', r, numel(sel)));
    xlabel('ranked neurons'); ylabel('loading'); xlim([0.5, numel(sel)+0.5]);

    % Row 2: B (time)
    nexttile(tl, R + jj);
    plot(B(:, r), '-', 'LineWidth', 1.5);
    grid on; title(sprintf('B_r=%d', r));
    xlabel('time'); ylabel('loading');

    % Row 3: C (trials)
    nexttile(tl, 2*R + jj);
    cr = C(:, r);
    if have_labels
        % hold is kept around the call so the axes are set up as before
        hold on;
        scatter((1:nTrials).', cr, 18, trialColors, 'filled');
        hold off; grid on;
        title(sprintf('C_r=%d (colored by label)', r));
    else
        stem(1:numel(cr), cr, 'filled'); grid on;
        title(sprintf('C_r=%d', r));
    end
    xlabel('trial'); ylabel('weight');
end

% Overall title for the whole tiled layout.
title(tl, suptitle_str);

% ---- optional: save the figure
% outname = sprintf('R%d_model%d_components.png', R_target, model_idx);
% exportgraphics(fig, outname, 'Resolution',300);
