% ================================
% Per-neuron mean & std, quantile cleaning, and plots
% ================================

%% ---- params ----
% Percentile window applied to the per-neuron means, histogram resolution,
% and the switch that controls plotting order.
q_lo          = 2.5;     % lower percentile used to trim on the means
q_hi          = 97.5;    % upper percentile used to trim on the means
nbins_mu      = 50;      % bin count for the histogram of means
nbins_std     = 50;      % bin count for the histogram of stds
order_by_mean = true;    % true: neurons are sorted by ascending mean

%% ---- data to double ----
% X must already exist in the workspace. Per the original note, double()
% also handles the case where X is a Tensor Toolbox tensor.
Xd = double(X);
[N, T, Tr] = size(Xd);   % any dimensions past the third are folded into Tr

%% ---- per-neuron stats over the whole experiment (time × trials) ----
% Flatten everything except the first dimension, so that row i holds all
% T*Tr samples belonging to entry i of the first dimension (one neuron).
Xr = reshape(Xd, [N, T*Tr]);             % N x (T*Tr)

% Row-wise statistics; NaN samples are ignored.
mu = mean(Xr, 2, 'omitnan');             % N x 1  (mean over all samples)
sd = std(Xr, [], 2, 'omitnan');          % N x 1  (std, default n-1 weighting)

%% ---- quantile cleaning on neuron means ----
% Thresholds are the q_lo-th and q_hi-th percentiles of the per-neuron means.
qlo = prctile(mu, q_lo);
qhi = prctile(mu, q_hi);

% A neuron is kept when its mean lies inside [qlo, qhi] (both ends included).
% A NaN mean fails both comparisons and is therefore dropped.
at_or_above_lo = (mu >= qlo);
at_or_below_hi = (mu <= qhi);
keep     = at_or_above_lo & at_or_below_hi;   % logical mask of kept neurons
idx_keep = find(keep);                        % same selection as linear indices

% One-line summary on the console.
n_kept   = numel(idx_keep);
pct_kept = 100*mean(keep);
fprintf('Kept %d/%d neurons (%.1f%%). Thresholds: [%.4g, %.4g]\n', ...
    n_kept, N, pct_kept, qlo, qhi);

%% ---- sort by mean (optional) ----
% 'order' is the permutation applied to the neurons for plotting: the
% identity by default, or the ascending-mean ordering when requested.
order = (1:N).';
if order_by_mean
    [~, order] = sort(mu, 'ascend');
end

% Apply the same permutation to the means, the stds and the keep mask so
% that the three stay aligned.
mu_s   = mu(order);
sd_s   = sd(order);
keep_s = keep(order);

%% ---- FIGURE 1: means with per-neuron std shadow; overlay kept means ----
% Drawn in the order given by mu_s / sd_s / keep_s:
%   * a translucent band from (mean - std) to (mean + std),
%   * the mean of every neuron (blue line),
%   * the mean of the kept neurons only (orange line),
%   * dotted horizontal lines at the two percentile thresholds.
figure('Color','w'); hold on;
x = 1:N;

% Band limits as row vectors. They are deliberately not called
% upper/lower: variables with those names would shadow the built-in
% string functions upper() and lower() in the workspace.
band_hi = (mu_s + sd_s).';
band_lo = (mu_s - sd_s).';

% A neuron whose samples are all NaN has a NaN mean/std. A polygon with
% NaN vertices is not filled, so such neurons are left out of the outline.
ok = isfinite(band_hi) & isfinite(band_lo);
if any(ok)
    px = [x(ok), fliplr(x(ok))];
    py = [band_hi(ok), fliplr(band_lo(ok))];
    fill(px, py, [0.6 0.6 0.9], 'FaceAlpha', 0.25, 'EdgeColor', 'none');
end

% base mean curve (all neurons)
plot(x, mu_s, 'Color', [0.2 0.2 0.7], 'LineWidth', 1.2);

% overlay: kept neurons only (same x positions), in orange and thicker;
% NaN at the rejected positions breaks the line there.
mu_keep_overlay = nan(size(mu_s));
mu_keep_overlay(keep_s) = mu_s(keep_s);
plot(x, mu_keep_overlay, 'Color', [0.95 0.45 0.10], 'LineWidth', 2.0);

% quantile lines (horizontal)
yline(qlo, ':', sprintf('Q_{%.1f} = %.3g', q_lo, qlo), 'Color',[0.3 0.3 0.3]);
yline(qhi, ':', sprintf('Q_{%.1f} = %.3g', q_hi, qhi), 'Color',[0.3 0.3 0.3]);

% The x-axis label has to reflect the ordering that was actually applied.
if order_by_mean
    xlabel('Neurons (sorted by mean)');
else
    xlabel('Neurons (original order)');
end
ylabel('Mean activity (a.u.)');
title(sprintf('Per-neuron mean with std shadow — kept (orange): %d/%d', sum(keep_s), N));
grid on; box on;

% xlim needs strictly increasing limits, so keep automatic limits when
% there are fewer than two neurons.
if N > 1
    xlim([1 N]);
end

%% ---- FIGURE 2: histogram of per-neuron means ----
% Distribution of the per-neuron means, with the two trimming thresholds
% marked as dashed red vertical lines.
figure('Color','w');
histogram(mu, nbins_mu, 'FaceColor',[0.2 0.6 0.9], 'EdgeColor','none');
hold('on');

% One vertical marker per threshold, lower first.
for thr = [qlo, qhi]
    xline(thr, 'r--', 'LineWidth',1.2);
end

xlabel('Per-neuron mean activity');
ylabel('Count');
fig2_title = sprintf('Distribution of neuron means (Q_{%.1f}=%.3g, Q_{%.1f}=%.3g)', ...
    q_lo, qlo, q_hi, qhi);
title(fig2_title);
grid('on');
box('on');

%% ---- FIGURE 3: histogram of per-neuron stds ----
% Distribution of the per-neuron standard deviations (all neurons,
% before any trimming), drawn as grey bars without edges.
figure('Color','w');
histogram(sd, nbins_std, ...
    'FaceColor', [0.4 0.4 0.4], ...
    'EdgeColor', 'none');
xlabel('Per-neuron standard deviation');
ylabel('Count');
title('Distribution of neuron std across time×trials');
grid('on');
box('on');
%% Save the final tensor
% Keeps only the neurons that passed the quantile cleaning, wraps the
% result as a Tensor Toolbox tensor, and writes both to 'tensor_cleaned.mat'
% together with the cleaning outputs and per-neuron statistics.

Xd = double(X);                 % ensure numeric for indexing

% 'results' is indexed and saved below, but nothing earlier in this script
% creates it, so on a clean workspace this section used to fail. Create it
% if it is missing (or not a scalar struct), and make the mask computed by
% this run authoritative. Any other fields of a pre-existing struct are
% left untouched.
if ~exist('results', 'var') || ~isstruct(results) || ~isscalar(results)
    results = struct();
end
results.keep     = keep;        % logical mask of kept neurons
results.idx_keep = idx_keep;    % the same selection as indices
results.qlo      = qlo;         % lower threshold on the means
results.qhi      = qhi;         % upper threshold on the means

% Select the kept entries along the first dimension.
Xf = Xd(results.keep, :, :);

% (optional) wrap for Tensor Toolbox ops
Xf_t = tensor(Xf);              % requires Tensor Toolbox

% --- save to disk (use -v7.3 for big arrays) ---
save('tensor_cleaned.mat', 'Xf', 'Xf_t', ...
     'results', 'mu', 'sd', 'idx_keep', '-v7.3');
%% Load
% Reads the file back into a struct, one field per saved variable.
S = load('tensor_cleaned.mat');  % gives S.Xf, S.Xf_t, S.results, ...
