function fig = wavelet_fast_rest(x, fs, xa, xb, J, waveletName, fastLevels)
% WAVELET_FAST_REST (versione con doppia normalizzazione energia)
% Pannelli:
%  (1) Segnale
%  (2) Ricostruzione da fastLevels (default D1..D4)
%  (3) Resto = (tutti gli altri dettagli) + A_J
%
% Mostra in ciascun titolo:
%   - %_band : percentuale rispetto a sum_k ||m_k||^2  (sommatoria = 100%)
%   - %_sig  : percentuale rispetto a ||x||^2          (NON somma a 100% con MODWT)

    % ---------- default ----------
    x = x(:).';
    T = numel(x);
    if nargin < 5 || isempty(J),            J = max(1, min(16, floor(log2(T)) - 1)); end
    if nargin < 6 || isempty(waveletName),  waveletName = 'db3'; end
    if nargin < 7 || isempty(fastLevels),   fastLevels = 1:min(4, J); end
    fastLevels = fastLevels(fastLevels>=1 & fastLevels<=J);
    if isempty(fastLevels), fastLevels = 1:min(4,J); end

    t = (0:T-1)/fs;

    % ---------- xa/xb: indici interi -> indici; altrimenti secondi ----------
    useIdx = false;
    if nargin >= 4 && ~isempty(xa) && ~isempty(xb)
        if isscalar(xa) && isscalar(xb) && xa>=1 && xb<=T && xa<xb ...
           && abs(xa-round(xa))<eps && abs(xb-round(xb))<eps
            useIdx = true;
        end
    end
    if nargin < 3 || isempty(xa), xa = useIdx*1 + (~useIdx)*t(1); end
    if nargin < 4 || isempty(xb), xb = useIdx*T + (~useIdx)*t(end); end
    if useIdx
        xa = max(1,min(T,round(xa))); xb = max(1,min(T,round(xb)));
        if xb<=xa, xa=1; xb=T; end
        xlim_sec = ([xa xb]-1)/fs;
    else
        xa = max(t(1),min(t(end),xa)); xb = max(t(1),min(t(end),xb));
        if xb<=xa, xa=t(1); xb=t(end); end
        xlim_sec = [xa xb];
    end

    % ---------- preprocess ----------
    m = mean(x(~isnan(x))); if isnan(m), m = 0; end
    x = x - m; x(isnan(x)) = 0;

    % ---------- MODWT & MRA ----------
    wt  = modwt(x, waveletName, J);     % (J+1) × T
    mra = modwtmra(wt, waveletName);    % (J+1) × T  (1..J: D1..DJ, J+1: A_J)

    % ---------- energie ----------
    bandE      = sum(mra.^2, 2);        % energie per banda
    bandsum    = sum(bandE);            % somma delle energie MRA
    sigsum     = sum(x.^2);             % energia del segnale
    if sigsum <= eps, sigsum = bandsum + eps; end

    % due viste di frazione d'energia
    Efrac_band = bandE / max(bandsum, eps);   % somma = 1
    Efrac_sig  = bandE / sigsum;              % NON somma a 1 con MODWT

    % ---------- ricostruzioni ----------
    allRows  = 1:(J+1);
    fastRows = fastLevels;                % dettagli selezionati
    restRows = setdiff(allRows, fastRows);

    recFast = sum(mra(fastRows, :), 1);
    recRest = sum(mra(restRows, :), 1);

    Ef_band = sum(Efrac_band(fastRows));
    Er_band = sum(Efrac_band(restRows));
    Ef_sig  = sum(Efrac_sig(fastRows));
    Er_sig  = sum(Efrac_sig(restRows));

    % ---------- figura 3 pannelli ----------
    fig = figure('Color','w','Name','Segnale, D1–D4, Resto');
    tl  = tiledlayout(fig, 3, 1, 'TileSpacing','compact', 'Padding','compact');

    % (1) segnale
    ax1 = nexttile(tl, 1);
    plot(t, x, 'LineWidth', 1.2);
    grid on; xlim(xlim_sec);
    ylabel('Signal');
    title(sprintf('Signal  (fs = %.3f Hz, MODWT(%s), J = %d)', fs, waveletName, J), 'Interpreter','none');

    % (2) fast
    ax2 = nexttile(tl, 2);
    plot(t, recFast, 'LineWidth', 1.2);
    grid on; xlim(xlim_sec);
    if numel(fastLevels)==1
        lbl = sprintf('D%d', fastLevels);
    else
        lbl = sprintf('D%d–D%d', min(fastLevels), max(fastLevels));
    end
    ylabel(lbl);
    title(sprintf('Reconstruction from %s  —  %.1f%% (band-sum), %.1f%% (||x||^2)', ...
                  lbl, 100*Ef_band, 100*Ef_sig));

    % (3) resto (altri dettagli + A_J)
    ax3 = nexttile(tl, 3);
    plot(t, recRest, 'LineWidth', 1.2);
    grid on; xlim(xlim_sec);
    ylabel('Rest');
    minRest = min(setdiff(1:J, fastLevels));
    if isempty(minRest)
        restTxt = sprintf('A_{%d}', J);
    else
        restTxt = sprintf('D%d–D%d + A_{%d}', minRest, J, J);
    end
    title(sprintf('Remainder (%s)  —  %.1f%% (band-sum), %.1f%% (||x||^2)', ...
                  restTxt, 100*Er_band, 100*Er_sig));
    xlabel('Time [s]');

    linkaxes([ax1 ax2 ax3], 'x');
end

% --- helper: periodo per livello dettaglio ---
function [Tlo, Thi] = level_period_band(j, fs)
    dt  = 1/fs;
    Tlo = 2^j    * dt;
    Thi = 2^(j+1)* dt;
end

% --- helper: periodo minimo per A_J ---
function Tlo = approx_period_band(J, fs)
    dt  = 1/fs;
    Tlo = 2^(J+1) * dt;
end
