%% ---------- choose which medoid to plot ----------
% Expects medoid_per_R (cell array), R_range (vector of ranks) and labels to
% already exist in the workspace; they are not defined in this section.
R_to_plot = 2;                              % <- pick the rank you want

% Resolve the rank to a cell index once, and fail with a readable message if
% the requested rank was never fitted (an empty brace-index would otherwise
% raise an obscure "0 results" error).
rank_idx = find(R_range == R_to_plot, 1);
assert(~isempty(rank_idx), 'R_to_plot = %d is not present in R_range.', R_to_plot);

P = medoid_per_R{rank_idx};                 % medoid at that R
M = P;                                      % alias kept for code that refers to M

plot_cp_medoid_components(P, labels, ...
    'scaleByLambda', false, ...              % false: plot raw factors (no lambda scaling)
    'topk', 100, ...                         % show top-k neurons per component
    'fitType', 'robust', ...                % 'robust' | 'ols' | 'none'
    'titlePrefix', sprintf('Medoid components  (R = %d)', R_to_plot));

%% ======================= local function (keep at end of file) =======================
function plot_cp_medoid_components(P, labels, varargin)
% PLOT_CP_MEDOID_COMPONENTS  Compact visualization of a CP model's components.
%
%   plot_cp_medoid_components(P, labels, Name, Value, ...)
%
% Opens a new figure with a 3-by-R tiled layout (R = number of columns of
% P.U{1}), one column of tiles per component:
%   Row 1: time kernels   - columns of P.U{2}
%   Row 2: trial weights  - columns of P.U{3}, one scatter series per label
%                           value, plus an optional per-class linear trend
%   Row 3: neuron loadings - the k entries of each column of P.U{1} with the
%                           largest absolute value (signs preserved)
%
% Inputs:
%   P      - anything exposing P.U{1}, P.U{2}, P.U{3} and P.lambda
%            (presumably a Tensor Toolbox ktensor - confirm with caller)
%   labels - one label per trial (numel must equal size(P.U{3},1)); values
%            are compared with == and printed with %g, so numeric labels
%            are assumed
%
% Options (name/value):
%   'scaleByLambda' (default true)     multiply every factor column by its lambda
%   'topk'          (default 50)       positive integer, neurons shown per component
%   'fitType'       (default 'robust') 'robust' | 'ols' | 'none'
%   'titlePrefix'   (default '')       title of the whole layout

    ip = inputParser;
    % Accept logical as before, and additionally numeric 0/1 flags.
    ip.addParameter('scaleByLambda', true, ...
        @(x)isscalar(x)&&(islogical(x)||(isnumeric(x)&&~isnan(x))));
    % maxk needs an integer count, so reject fractional values up front.
    ip.addParameter('topk', 50, @(x)isnumeric(x)&&isscalar(x)&&x>0&&x==round(x));
    ip.addParameter('fitType', 'robust', @(s)ischar(s)||isstring(s));
    ip.addParameter('titlePrefix', '', @(s)ischar(s)||isstring(s));
    ip.parse(varargin{:});
    S = ip.Results;

    % Normalize the fit type once; unknown values keep the old "draw no
    % line" behavior but are now reported instead of silently ignored.
    fitType = lower(char(S.fitType));
    if ~any(strcmp(fitType, {'robust','ols','none'}))
        warning('plot_cp_medoid_components:unknownFitType', ...
                'Unknown fitType ''%s''; no trend lines will be drawn.', fitType);
        fitType = 'none';
    end
    % robustfit lives in an optional toolbox; check availability once.
    haveRobust = exist('robustfit','file')==2;

    A = double(P.U{1}); B = double(P.U{2}); C = double(P.U{3});
    lam = double(P.lambda(:));
    [N,T,R] = deal(size(A,1), size(B,1), size(A,2));
    y = labels(:);
    assert(numel(y)==size(C,1), 'labels length must match #trials');

    % scale factors by lambda if requested (row vector broadcasts over rows)
    if S.scaleByLambda
        A = A .* lam.'; B = B .* lam.'; C = C .* lam.';
        trialYLabel = '\lambda \cdot C';
    else
        trialYLabel = 'C';   % unscaled: do not claim a lambda factor
    end

    % prep layout: 3 rows x R columns
    figure('Color','w','Name','CP medoid components');
    tl = tiledlayout(3, R, 'TileSpacing','compact', 'Padding','compact');
    title(tl, S.titlePrefix, 'FontWeight','bold');

    % ---------- row 1: TIME kernels ----------
    for r = 1:R
        nexttile(r);
        plot(1:T, B(:,r), '-o', 'LineWidth', 1.2, 'MarkerSize', 3); grid on;
        xlabel('Time'); ylabel('Weight');
        title('Time');
    end

    % ---------- row 2: TRIAL weights (colored) + optional trend lines ----------
    t = (1:size(C,1))';
    cls = unique(y);                 % unique already returns sorted values
    cmap = lines(numel(cls));
    for r = 1:R
        nexttile(R + r); hold on;
        s = C(:,r);
        hScatter = gobjects(numel(cls), 1);   % one legend handle per class
        for ci = 1:numel(cls)
            mask = (y==cls(ci));
            hScatter(ci) = scatter(t(mask), s(mask), 18, ...
                    'MarkerFaceColor', cmap(ci,:), ...
                    'MarkerEdgeColor', 'none', 'MarkerFaceAlpha', .9);
            coef = cp_medoid_trend_coef(t(mask), s(mask), fitType, haveRobust);
            if ~isempty(coef)
                % Trend is fitted on this class only but drawn over all trials.
                plot(t, coef(1) + coef(2)*t, '-', 'Color', cmap(ci,:), 'LineWidth', 1.2);
            end
        end
        grid on; xlabel('Trial'); ylabel(trialYLabel);
        title('Trials');
        if r==R && ~isempty(cls)
            lgd = arrayfun(@(c)sprintf('class %g', c), cls, 'UniformOutput', false);
            % Pass the scatter handles explicitly: otherwise legend labels the
            % first numel(cls) plotted objects, which alternate scatter/line.
            legend(hScatter, lgd, 'Location','best', 'Box','off');
        end
        hold off;
    end

    % ---------- row 3: NEURON loadings (top-k by magnitude) ----------
    k = min(S.topk, N);
    for r = 1:R
        nexttile(2*R + r);
        ar = A(:,r);
        % Rank by absolute value so strong negative loadings are kept,
        % then plot the signed values.
        [~, idx] = maxk(abs(ar), k);
        stem(1:k, ar(idx), 'filled'); grid on;
        xlabel('Top-k neurons'); ylabel('Weight');
        title(sprintf('Neurons, [top-%d]', k));
    end
end

function coef = cp_medoid_trend_coef(x, v, fitType, haveRobust)
% Linear trend v ~ coef(1) + coef(2)*x, or [] when no line should be drawn.
% 'robust' uses robustfit when it is available and there are at least 3
% points (robustfit is expected to reject fewer - confirm against its docs);
% otherwise it falls back to ordinary least squares via polyfit.
    coef = [];
    n = numel(x);
    if strcmp(fitType, 'none') || n < 2
        return;   % a line through fewer than 2 points is not determined
    end
    if strcmp(fitType, 'robust') && haveRobust && n >= 3
        coef = robustfit(x, v);          % [intercept; slope]
    else
        p = polyfit(x, v, 1);            % [slope, intercept]
        coef = [p(2); p(1)];
    end
end
