function varargout = overlap_heatmap_top5(M, comps, top_percent, use_abs)
% OVERLAP_HEATMAP_TOP5 Heatmap of overlap among top-k% neurons (by loading on each component).
%
% For each requested component (column of M.U{1}), the neurons with the
% largest loadings form a "top set" of size
%   max(1, round(n_neurons * top_percent/100))   (capped at n_neurons).
% Every pair of top sets is compared and shown as a heatmap.
% - Overlap metric: Jaccard index |A∩B| / |A∪B|
% - Cell annotations: "∩=count | J=xx%"
%
% Inputs:
%   M           - decomposition whose factor matrices are in M.U; rows of
%                 M.U{1} are treated as neurons. Presumably a Tensor Toolbox
%                 ktensor, since normalize/arrange/fixsigns are applied — confirm.
%   comps       - positive integer component indices in 1..R   [default: [1 7 8 9]]
%   top_percent - percentage of neurons kept per component     [default: 5]
%   use_abs     - rank neurons by |loading| instead of signed  [default: false]
%
% Outputs (optional, returned only when requested):
%   J - K×K Jaccard indices         I - K×K intersection counts
%   U - K×K union counts            S - K×1 cell, sorted neuron indices per component
% The same four values are always also written to the base workspace as
% overlap_Jaccard, overlap_intersections, overlap_unions and overlap_sets.
%
% Usage:
%   overlap_heatmap_top5(M);                      % defaults comps=[1 7 8 9], top_percent=5, use_abs=false
%   overlap_heatmap_top5(M, [2 4 5 6], 5, true);  % use absolute loadings
%   [J, I] = overlap_heatmap_top5(M, [2 4 5 6]);  % also return the matrices

    nargoutchk(0, 4);

    if nargin < 2 || isempty(comps),       comps = [1 7 8 9]; end
    if nargin < 3 || isempty(top_percent), top_percent = 5;  end
    if nargin < 4 || isempty(use_abs),     use_abs = false;  end

    % Normalize, arrange, fix signs (λ in M.lambda)
    M = fixsigns(arrange(normalize(M,0)));
    U1 = M.U{1};            % neurons × R
    R  = size(U1,2);

    % Row vector so labels and loops behave the same for row or column input.
    comps = comps(:).';
    if any(comps ~= fix(comps))
        error('Component indices must be integers.');
    end
    if any(comps < 1 | comps > R)
        error('Component index out of range. Available components: 1..%d', R);
    end

    n_neu  = size(U1,1);
    % At least one neuron per set, but never more neurons than exist
    % (otherwise the reported per-set count is wrong when top_percent > 100).
    n_keep = min(n_neu, max(1, round(n_neu * top_percent/100)));
    K      = numel(comps);

    % --- Build top-sets per component ---
    S      = cell(K,1);          % sorted index sets
    member = false(n_neu, K);    % member(n,k) is true iff neuron n is in top set k
    for k = 1:K
        loadings = U1(:, comps(k));
        if use_abs, loadings = abs(loadings); end
        [~, idx_top] = maxk(loadings, n_keep);
        S{k} = sort(idx_top(:));
        member(S{k}, k) = true;
    end

    % --- Pairwise intersection, union and Jaccard ---
    % For indicator columns, B'*B counts shared neurons; |A∪B| = |A| + |B| - |A∩B|.
    B         = double(member);
    I         = B.' * B;                   % intersection counts
    set_sizes = sum(B, 1);                 % 1×K set sizes
    U         = set_sizes.' + set_sizes - I;   % union counts (implicit expansion)
    J         = zeros(K);                  % Jaccard; defined as 0 for empty unions
    has_union = U > 0;
    J(has_union) = I(has_union) ./ U(has_union);

    % --- Plot heatmap (Jaccard) ---
    pct_str = sprintf('%g', top_percent);  % %g keeps e.g. 2.5 instead of rounding to 3
    figure('Color','w','Name',sprintf('Top-%s%% overlap (Jaccard)', pct_str));
    imagesc(J); axis square;
    colormap(parula); colorbar;
    caxis([0 1]);
    labels = arrayfun(@(x) sprintf('Comp %d', x), comps, 'UniformOutput', false);
    xticks(1:K); yticks(1:K);
    xticklabels(labels);
    yticklabels(labels);
    title(sprintf('Overlap of top %s%% neurons (Jaccard)', pct_str));

    % Annotate cells with ∩ count and J%. parula turns bright yellow toward 1
    % (the diagonal is always J = 1), so use black text on high-overlap cells.
    hold on;
    for i = 1:K
        for j = 1:K
            txt = sprintf('\\cap=%d | J=%2.0f%%', I(i,j), 100*J(i,j));
            if J(i,j) > 0.6
                txt_color = 'k';
            else
                txt_color = 'w';
            end
            text(j, i, txt, 'HorizontalAlignment','center', 'Color',txt_color, 'FontWeight','bold');
        end
    end
    hold off;

    % Also print the set sizes
    fprintf('Top %d (%g%%) neuron counts per component: %d each\n', n_keep, top_percent, n_keep);
    for k = 1:K
        fprintf('Comp %d: %d neurons in top-set\n', comps(k), numel(S{k}));
    end

    % Results always go to the base workspace (original behavior) ...
    assignin('base','overlap_Jaccard',J);
    assignin('base','overlap_intersections',I);
    assignin('base','overlap_unions',U);
    assignin('base','overlap_sets',S);

    % ... and are returned only when the caller asks, so calls without
    % outputs do not set `ans`.
    outs = {J, I, U, S};
    varargout = outs(1:nargout);
end
