function fig = wavelet_decomp_piled(x, fs, xa, xb, J, waveletName, topK)
% WAVELET_DECOMP_PILED (rev)
%   Figura "piled":
%     • riga 1: segnale
%     • righe successive: D1..D4 (se esistono) + le 4 bande con più energia (senza duplicati; può includere A_J)
%     • riga finale: heatmap 1×(J+1) con la frazione di energia per TUTTE le bande (D1..DJ, A_J)
%
%   La frazione di energia è normalizzata sulla somma delle energie delle bande
%   (sum_k ||m_k||^2), così la somma fa sempre 100% anche con MODWT.
%
% Uso:
%   fig = wavelet_decomp_piled(x, fs, xa, xb);                 % default J, 'db3', topK=4
%   fig = wavelet_decomp_piled(x, fs, xa, xb, 16, 'db3', 4);   % esplicito
%
% Note:
%   - xa/xb possono essere secondi oppure indici (se interi in [1,T] → interpretati come indici).
%   - Richiede Wavelet Toolbox (modwt, modwtmra).

    % ---------- default ----------
    x = x(:).';                        % riga
    T = numel(x);
    if nargin < 5 || isempty(J),            J = max(1, min(16, floor(log2(T)) - 1)); end
    if nargin < 6 || isempty(waveletName),  waveletName = 'db3'; end
    if nargin < 7 || isempty(topK),         topK = 4; end

    t = (0:T-1)/fs;

    % ---------- interpreta xa/xb: indici interi -> indici; altrimenti secondi ----------
    useIdx = false;
    if nargin >= 4 && ~isempty(xa) && ~isempty(xb)
        if isscalar(xa) && isscalar(xb) && xa>=1 && xb<=T && xa<xb ...
           && abs(xa-round(xa))<eps && abs(xb-round(xb))<eps
            useIdx = true;
        end
    end
    if nargin < 3 || isempty(xa), xa = useIdx*1 + (~useIdx)*t(1); end
    if nargin < 4 || isempty(xb), xb = useIdx*T + (~useIdx)*t(end); end
    if useIdx
        xa = max(1, min(T, round(xa)));
        xb = max(1, min(T, round(xb)));
        if xb <= xa, xa = 1; xb = T; end
        xlim_sec = ([xa xb]-1)/fs;
    else
        xa = max(t(1), min(t(end), xa));
        xb = max(t(1), min(t(end), xb));
        if xb <= xa, xa = t(1); xb = t(end); end
        xlim_sec = [xa xb];
    end

    % ---------- preprocess ----------
    m = mean(x(~isnan(x))); if isnan(m), m = 0; end
    x = x - m; x(isnan(x)) = 0;

    % ---------- MODWT & MRA ----------
    wt  = modwt(x, waveletName, J);      % (J+1) × T
    mra = modwtmra(wt, waveletName);     % (J+1) × T  (righe 1..J = D1..DJ, riga J+1 = A_J)

    % ---------- energie ----------
    bandE    = sum(mra.^2, 2);           % energia per banda
    bandsum  = sum(bandE);               % somma energie bande (usata per avere 100%)
    if bandsum <= eps, bandsum = eps; end
    Efrac    = bandE / bandsum;          % somma = 1
    % (se ti serve anche rispetto a ||x||^2: Efrac_sig = bandE / max(sum(x.^2), bandsum);)

    % ---------- selezione righe da plottare ----------
    baseIdx = 1:min(4, J);               % D1..D4 (se esistono)
    [~, sortedIdx] = sort(Efrac, 'descend');     % include A_J (indice J+1)
    sel = baseIdx;                               
    for k = 1:min(topK, numel(sortedIdx))
        idx = sortedIdx(k);
        if ~ismember(idx, sel)
            sel(end+1) = idx; %#ok<AGROW>
        end
    end
    % mantieni D1..D4 in testa, ordina gli extra per numero di livello
    if numel(sel) > numel(baseIdx)
        extra = sel(numel(baseIdx)+1:end);
        extra = sort(extra);
        sel   = [baseIdx, extra];
    end

    % ---------- etichette ----------
    detailLabels = arrayfun(@(j) sprintf('D%d', j), 1:J, 'UniformOutput', false);
    allLabels    = [detailLabels, {sprintf('A_{%d}', J)}];

    % ---------- figura ----------
    nRows = 1 + numel(sel) + 1;          % segnale + selezionate + heatmap
    fig = figure('Color','w','Name','Piled wavelet view (D1..D4 + top-energy)');
    tl  = tiledlayout(nRows, 1, 'TileSpacing','compact', 'Padding','compact');

    % riga 1: segnale
    ax1 = nexttile(tl, 1);
    plot(t, x, 'LineWidth', 1.2);
    grid on; xlim(xlim_sec);
    ylabel('Signal');
    title(sprintf('Signal (fs = %.3f Hz), MODWT(%s), J = %d', fs, waveletName, J), 'Interpreter','none');

    % righe 2..: componenti selezionate
    axSel = gobjects(numel(sel),1);
    for r = 1:numel(sel)
        idx = sel(r);
        axSel(r) = nexttile(tl, 1 + r);
        plot(t, mra(idx,:), 'LineWidth', 1.0);
        grid on; xlim(xlim_sec);
        ylabel(allLabels{idx}, 'Interpreter','tex');

        % annotazioni: banda temporale + energia (sommatoria=100%)
        if idx <= J
            [Tlo, Thi] = level_period_band(idx, fs);
            ttl = sprintf('%s  |  T\\in[%.3g, %.3g) s  |  Energy = %.1f%%%%', ...
                          allLabels{idx}, Tlo, Thi, 100*Efrac(idx));
        else
            TloA = approx_period_band(J, fs);
            ttl = sprintf('%s  |  T\\ge %.3g s  |  Energy = %.1f%%%%', ...
                          allLabels{idx}, TloA, 100*Efrac(idx));
        end
        title(axSel(r), ttl, 'Interpreter','tex');

        if r < numel(sel), set(axSel(r), 'XTickLabel', []); end
    end

    % riga finale: heatmap per tutte le bande (D1..DJ, A_J) — include A_J
    axH = nexttile(tl, nRows);
    imagesc(axH, 1:(J+1), 1, Efrac(:).');     % 1×(J+1)
    set(axH, 'YDir','normal');
    colormap(axH, parula(256));
    caxis(axH, [0 1]);                        % scala in [0,1] (somma=1)
    cb = colorbar(axH, 'Location','eastoutside');
    cb.Label.String = 'Energy fraction (sum = 100%)';
    yticks(axH, []); ylim(axH, [0.5 1.5]);
    xticks(axH, 1:(J+1));
    xticklabels(axH, allLabels);
    axH.XAxis.TickLabelInterpreter = 'tex';
    axH.XTickLabelRotation = 45;
    title(axH, 'Per-band energy (D1..DJ, A_{J})', 'Interpreter','tex');

    % link assi x per i grafici tempo-serie
    linkaxes([ax1; axSel(:)], 'x');
end

% ----------------- helpers -----------------
function [Tlo, Thi] = level_period_band(j, fs)
    dt  = 1/fs;
    Tlo = 2^j    * dt;
    Thi = 2^(j+1)* dt;
end

function Tlo = approx_period_band(J, fs)
    dt  = 1/fs;
    Tlo = 2^(J+1) * dt;
end
