function visualize_components_2_4_5_6(M, labels, scale_by_lambda, fit_method, comps)
% VISUALIZE_COMPONENTS_2_4_5_6  One figure per CP component: trial weights and time profile.
% By default plots components [2, 4, 5, 6], one separate figure each.
% Each figure has 2x1 panels:
%   (top)    trial-mode weights M.U{3}(:,r) as a scatter colored by class, plus one
%            fitted straight trend line per class (only classes with >= 2 trials),
%   (bottom) time-mode profile M.U{2}(:,r).
%
% Usage:
%   visualize_components_2_4_5_6(M, labels);                    % robust fit, no lambda scaling
%   visualize_components_2_4_5_6(M, labels, true);              % robust fit, scale by lambda
%   visualize_components_2_4_5_6(M, labels, false, 'polyfit');  % OLS fit
%   visualize_components_2_4_5_6(M, labels, [], [], [1 3]);     % other components
%
% Inputs:
%   M               : ktensor (Tensor Toolbox) with at least 3 modes. Mode 2 is
%                     plotted as "Time", mode 3 as "Trial".
%   labels          : trial labels, positive integers, exactly one per row of M.U{3}
%                     (1=red, 2=blue; labels > 2 get colors from LINES).
%   scale_by_lambda : logical (default false); multiply plotted weights by lambda(r).
%   fit_method      : 'robust' (default) or 'polyfit', case-insensitive.
%                     'robust' uses ROBUSTFIT when it is on the path and falls
%                     back to POLYFIT (with a warning) otherwise or on failure.
%   comps           : component indices to plot (default [2 4 5 6]). Indices refer
%                     to the component order after NORMALIZE/ARRANGE/FIXSIGNS below.
%
% Errors:
%   - fit_method not 'robust'/'polyfit'
%   - M with fewer than 3 modes, or comps outside 1..R
%   - labels not matching the number of trials, or not positive integers

    if nargin < 3 || isempty(scale_by_lambda), scale_by_lambda = false; end
    if nargin < 4 || isempty(fit_method),      fit_method      = 'robust'; end
    if nargin < 5 || isempty(comps),           comps           = [2 4 5 6]; end

    % Validate the fit method once, before any figure is created.
    method = lower(char(fit_method));
    if ~any(strcmp(method, {'robust', 'polyfit'}))
        error('fit_method must be ''robust'' or ''polyfit''.');
    end
    use_robust = strcmp(method, 'robust');
    if use_robust && exist('robustfit', 'file') == 0
        % Warn once instead of silently swallowing an error for every class/component.
        warning('visualize_components:noRobustfit', ...
            'robustfit not found on the MATLAB path; using polyfit instead.');
        use_robust = false;
    end

    % Normalize, arrange, fix signs (puts weights in M.lambda)
    M = fixsigns(arrange(normalize(M,0)));
    U   = M.U;
    lam = M.lambda(:);

    if numel(U) < 3
        error('M must have at least 3 modes (mode 2 = time, mode 3 = trials).');
    end

    R = size(U{1},2);
    comps = comps(:)';
    if any(comps < 1 | comps > R | comps ~= round(comps))
        error('Requested components exceed available range 1..%d', R);
    end

    % One label per trial; labels double as row indices into the color map.
    labels = double(labels(:));
    T = size(U{3},1);
    if numel(labels) ~= T
        error('labels has %d elements but M.U{3} has %d rows (one label per trial expected).', ...
            numel(labels), T);
    end
    if any(labels < 1 | labels ~= round(labels))
        error('labels must be positive integers (they index the class color map).');
    end

    % Loop-invariant setup shared by all component figures.
    x_trials = (1:T)';
    ulabels  = unique(labels)';            % row vector so "for cls = ulabels" iterates classes
    cmap     = class_colors(max([labels; 2]));
    if scale_by_lambda
        ylab       = '\lambda \cdot Weight';
        fig_suffix = ' [scaled by \lambda]';
    else
        ylab       = 'Weight';
        fig_suffix = '';
    end

    for r = comps
        % New figure per component
        figure('Color','w','Name',sprintf('Component %d: Trials & Time', r));
        tl = tiledlayout(2,1,'Padding','compact','TileSpacing','compact');

        % ================= TOP: TRIALS SCATTER + TREND LINES =================
        ax1 = nexttile(tl,1); hold(ax1,'on');

        y_trials = U{3}(:, r);
        if scale_by_lambda, y_trials = lam(r) * y_trials; end

        % scatter per class (one legend entry per class)
        for cls = ulabels
            idx = (labels == cls);
            scatter(ax1, x_trials(idx), y_trials(idx), 40, cmap(cls,:), ...
                'filled', 'MarkerFaceAlpha', 0.9, 'DisplayName', sprintf('Type %d', cls));
        end

        % fit line per class, evaluated over all trials; a line needs >= 2 points
        for cls = ulabels
            idx = (labels == cls);
            if nnz(idx) >= 2
                yhat = fit_trend(x_trials(idx), y_trials(idx), x_trials, use_robust);
                % Darkened class color: gives [0.7 0 0] / [0 0 0.7] for classes 1 / 2
                % and keeps extra classes distinguishable.
                plot(ax1, x_trials, yhat, 'LineWidth', 2, 'Color', 0.7 * cmap(cls,:), ...
                    'DisplayName', sprintf('Trend %d', cls));
            end
        end

        grid(ax1,'on'); box(ax1,'on');
        xlabel(ax1,'Trial');
        ylabel(ax1, ylab, 'Interpreter','tex');
        title(ax1, sprintf('Component %d – Trials (\\lambda=%.3g)', r, lam(r)));
        legend(ax1, 'Location','best');

        % ================= BOTTOM: TIME PROFILE =================
        ax2 = nexttile(tl,2); hold(ax2,'on');
        y_time = U{2}(:, r);
        if scale_by_lambda, y_time = lam(r) * y_time; end
        plot(ax2, y_time, 'LineWidth', 1.5, 'Color', [0 0.45 0.74]);
        grid(ax2,'on'); box(ax2,'on');
        xlabel(ax2,'Time');
        ylabel(ax2, ylab, 'Interpreter','tex');
        title(ax2, sprintf('Component %d – Time', r));

        % Figure-level title (fig_suffix is inserted verbatim, so '\lambda' reaches the TeX interpreter)
        title(tl, sprintf('Component %d – Trials (top) & Time (bottom)%s, fit=%s', ...
            r, fig_suffix, fit_method));
    end
end

function cmap = class_colors(n)
% Scatter colors for classes 1..n: 1=red, 2=blue, classes 3..n use rows 3..n of LINES(n).
    cmap = [1 0 0; 0 0 1];
    if n > 2
        extra = lines(n);
        cmap = [cmap; extra(3:n,:)];
    end
end

function yhat = fit_trend(x, y, xq, use_robust)
% Straight-line fit of y on x evaluated at xq: ROBUSTFIT if use_robust, else POLYFIT (OLS).
% A ROBUSTFIT failure is reported as a warning and the OLS fit is used instead.
    if use_robust
        try
            b = robustfit(x, y);            % [intercept; slope]
            yhat = b(1) + b(2) * xq;
            return
        catch ME
            warning('visualize_components:robustfitFailed', ...
                'robustfit failed (%s); falling back to polyfit.', ME.message);
        end
    end
    p = polyfit(x, y, 1);                   % [slope intercept]
    yhat = polyval(p, xq);
end

