% ------------------------
% PARAMETERS
% ------------------------
component_idx = 3;     % index of the CP component whose neuron loadings are inspected
top_percent   = 100;   % percentage of neurons (highest loadings first) to keep
arrow_scale   = 0.3;   % display-only factor that shortens the early->late arrows

% ------------------------
% LOADINGS & TOP NEURONS
% ------------------------
% NOTE(review): M is presumably a CP/Kruskal model whose first factor matrix M.U{1}
% has one row per neuron and one column per component - confirm against the fitting code.
U1        = M.U{1};
loadings  = U1(:, component_idx);
n_neurons = size(U1, 1);

% Turn the percentage into a neuron count; the floor of 1 keeps the selection non-empty.
n_keep = round(n_neurons * top_percent / 100);
n_keep = max(1, n_keep);

% Rank by signed loading, largest first. NOTE(review): if the CP fit allows negative
% loadings, strongly negative neurons end up last - confirm this is the intended ranking.
[~, idx_top] = maxk(loadings, n_keep);

% ------------------------
% TRIAL LABELS (1=A, 2=B)  <-- (keep consistent with your labels)
% ------------------------
% Trial indices per stimulus. "First"/"last" below assume that trial index order is
% chronological - TODO confirm for the dataset in use.
trial_type_A = find(labels == 1);
trial_type_B = find(labels == 2);

if numel(trial_type_A) < 2 || numel(trial_type_B) < 2
    error('You need at least two trials for both A and B types.');
end

% First/last K trials per stimulus (K=5). K is clamped to floor(n/2) so the early and
% late windows never share trials. Clamping to min(K, n) would let them overlap (and
% with n <= K make them identical), which pulls every early->late delta toward zero.
% The check above guarantees floor(n/2) >= 1, so each window holds at least one trial.
K = 5;
nA = numel(trial_type_A);
nB = numel(trial_type_B);
K_Af = min(K, floor(nA / 2));
K_Al = K_Af;
K_Bf = min(K, floor(nB / 2));
K_Bl = K_Bf;

firstA = trial_type_A(1:K_Af);
lastA  = trial_type_A(end-K_Al+1:end);
firstB = trial_type_B(1:K_Bf);
lastB  = trial_type_B(end-K_Bl+1:end);

% ------------------------
% RESPONSE EXTRACTION (mean over time, then mean across the chosen trials)
% activity_tensor: N x Time x Trials
% ------------------------
% For each selected neuron:
%   x1 / x2 = response to stim A, averaged over time (dim 2) and over the first / last
%             trials (dim 3)
%   y1 / y2 = the same for stim B
x1 = squeeze(mean(mean(activity_tensor(idx_top, :, firstA), 2), 3));  % (n_keep x 1)
x2 = squeeze(mean(mean(activity_tensor(idx_top, :, lastA ), 2), 3));
y1 = squeeze(mean(mean(activity_tensor(idx_top, :, firstB), 2), 3));
y2 = squeeze(mean(mean(activity_tensor(idx_top, :, lastB ), 2), 3));

% Ensure column vectors (squeeze of a single neuron yields a scalar, of several a column)
x1 = x1(:); y1 = y1(:); x2 = x2(:); y2 = y2(:);

% Early->late deltas (A on x-axis, B on y-axis), shrunk by arrow_scale for display.
u = (x2 - x1) * arrow_scale;
v = (y2 - y1) * arrow_scale;

% ------------------------
% PLOTTING (scatter + arrows)
% ------------------------
figure; hold on;
scatter(x1, y1, 60, 'filled', 'MarkerFaceColor', [0.2 0.6 0.9]);
% Scale argument 0 disables quiver autoscaling, so arrow length is exactly (u, v).
quiver(x1, y1, u, v, 0, 'Color', [0.1 0.1 0.6], 'LineWidth', 1.5);
xlabel(sprintf('Stim A (mean of first %d)', K_Af));
ylabel(sprintf('Stim B (mean of first %d)', K_Bf));
% The arrows cover only arrow_scale of the true displacement, so they point toward
% (not onto) the late mean; the title states the scale so the plot is not misread.
title(sprintf(['Top %.0f%% neurons – CP component %d\n' ...
               'Arrows point toward mean of last %d trials (length x%.2g)'], ...
              top_percent, component_idx, max([K_Al K_Bl]), arrow_scale));
legend({'Early mean', '→ Late mean'}, 'Location', 'best');
grid on; axis equal;

% ========= GLYPH (proportions over the TOTAL number of arrows) =========
% Summarizes the early->late arrows (u, v) by quadrant: NE (u>0, v>0), NO = north-west
% (u<0, v>0), SE (u>0, v<0), SO = south-west (u<0, v<0); NO/SO follow the Italian
% "nord-ovest"/"sud-ovest". Every proportion divides by the total arrow count N, so
% arrows with a zero component (on_axes) - and arrows with a NaN component, which fail
% every comparison - count only in the denominator.
if ~(exist('u','var') && exist('v','var'))
    error('u,v are required.');
end
N = numel(u);
if N == 0
    error('No arrows.');
end

% Strict sign tests: a zero component keeps an arrow out of every quadrant mask.
is_NE = (v > 0) & (u > 0);
is_NO = (v > 0) & (u < 0);
is_SE = (v < 0) & (u > 0);
is_SO = (v < 0) & (u < 0);
on_axes = (v == 0) | (u == 0);

cNE = sum(is_NE);
cNO = sum(is_NO);
cSE = sum(is_SE);
cSO = sum(is_SO);

pNE = cNE / N;
pNO = cNO / N;
pSE = cSE / N;
pSO = cSO / N;

% Each arrow's length equals its quadrant's proportion (100% reaches the guide circle),
% laid out along the diagonals in the order NE, NO, SE, SO.
glyph_max_len = 1.0;
r = glyph_max_len;
L  = glyph_max_len * [pNE, pNO, pSE, pSO];
theta = [pi/4, 3*pi/4, -pi/4, -3*pi/4];
dy = L .* sin(theta);
dx = L .* cos(theta);

fig = figure('Color', 'w'); %#ok<NASGU>
ax = axes;
hold(ax, 'on');
% Reference guides: dotted circle of radius r plus both diagonals.
rectangle('Position', [-r, -r, 2*r, 2*r], 'Curvature', [1 1], 'LineStyle', ':', ...
    'EdgeColor', [0 0 0], 'LineWidth', 0.75);
plot([-r, r], [-r, r], 'k:', 'LineWidth', 0.75);
plot([-r, r], [r, -r], 'k:', 'LineWidth', 0.75);
% One arrow per quadrant, drawn NE, NO, SE, SO (the order fixes the color-cycle colors).
for q = 1:4
    quiver(0, 0, dx(q), dy(q), 0, 'LineWidth', 2);
end
axis equal;
xlim([-r, r]);
ylim([-r, r]);
box on;
set(ax, 'XTick', [], 'YTick', []);
title(sprintf('Glyph (counts/TOTAL). Early=first %d, Late=last %d', max([K_Af K_Bf]), max([K_Al K_Bl])));

% Footer with raw counts, percentages and the number of arrows that lie on an axis.
pct = 100 * [pNE, pNO, pSE, pSO];
txt = sprintf('NE %d (%.1f%%) | NO %d (%.1f%%) | SE %d (%.1f%%) | SO %d (%.1f%%) | On-axes: %d/%d', ...
    cNE, pct(1), cNO, pct(2), cSE, pct(3), cSO, pct(4), sum(on_axes), N);
annotation('textbox', [0.0 0.0 1.0 0.08], 'String', txt, 'EdgeColor', 'none', ...
    'HorizontalAlignment', 'center', 'VerticalAlignment', 'middle', 'FontSize', 9);

% ========= LENGTH-WEIGHTED GLYPH =========
% Same quadrant split as the count glyph, but each arrow contributes its length
% hypot(u, v) instead of 1, so large early->late changes dominate. Reuses is_NE/is_NO/
% is_SE/is_SO, on_axes, N, theta, glyph_max_len and r computed earlier in this script.
w = hypot(u, v);
include_axes_in_denominator = true;  % true: normalize by total movement of all arrows;
                                     % false: only arrows strictly inside a quadrant

% Explicit branch rather than flag-weighted arithmetic (flag*sumA + ~flag*sumB), which
% would yield NaN whenever the unused sum is Inf because 0*Inf = NaN.
if include_axes_in_denominator
    W_den = sum(w);
else
    W_den = sum(w(~on_axes));
end
if ~isfinite(W_den)
    % NaN/Inf lengths (e.g. from NaN activity) would otherwise propagate silently.
    warning('Non-finite total movement (NaN/Inf in u or v); weighted proportions are undefined.');
elseif W_den == 0
    warning('Zero total movement.');
    W_den = 1;   % avoid 0/0: every weighted proportion then becomes 0
end

W_NE = sum(w(is_NE)); W_NO = sum(w(is_NO));
W_SE = sum(w(is_SE)); W_SO = sum(w(is_SO));
qNE = W_NE / W_den; qNO = W_NO / W_den; qSE = W_SE / W_den; qSO = W_SO / W_den;

% Arrow lengths are the weighted proportions, in the order NE, NO, SE, SO.
L2  = glyph_max_len * [qNE, qNO, qSE, qSO];
dx2 = L2 .* cos(theta); dy2 = L2 .* sin(theta);

fig = figure('Color','w'); ax = axes; hold(ax,'on'); %#ok<NASGU>
rectangle('Position',[-r -r 2*r 2*r], 'Curvature',[1 1], 'LineStyle',':', 'EdgeColor',[0 0 0], 'LineWidth',0.75);
plot([-r r], [-r r], 'k:', 'LineWidth', 0.75);
plot([-r r], [ r -r], 'k:', 'LineWidth', 0.75);
quiver(0,0,dx2(1),dy2(1),0,'LineWidth',2); % NE
quiver(0,0,dx2(2),dy2(2),0,'LineWidth',2); % NO (north-west)
quiver(0,0,dx2(3),dy2(3),0,'LineWidth',2); % SE
quiver(0,0,dx2(4),dy2(4),0,'LineWidth',2); % SO (south-west)
axis equal; xlim([-r r]); ylim([-r r]); box on; ax.XTick=[]; ax.YTick=[];
den_str = ternary(include_axes_in_denominator,'TOTAL','quad-defined');
title(sprintf('Glyph weighted by length (den=%s). Early=first %d, Late=last %d', ...
      den_str, max([K_Af K_Bf]), max([K_Al K_Bl])));

% Footer: weighted percentages, raw summed lengths per quadrant, and the denominator.
pct2 = 100*[qNE qNO qSE qSO];
txt2 = sprintf(['NE %.1f%% (W=%.3g) | NO %.1f%% (W=%.3g) | SE %.1f%% (W=%.3g) | SO %.1f%% (W=%.3g)\n' ...
                'On-axes: %d/%d | W_{den}=%.3g'], ...
                pct2(1),W_NE, pct2(2),W_NO, pct2(3),W_SE, pct2(4),W_SO, sum(on_axes), N, W_den);
annotation('textbox',[0.0 0.0 1.0 0.08],'String',txt2,'EdgeColor','none','HorizontalAlignment','center','FontSize',9);

% ------- Helper inline (ternary) -------
function out = ternary(cond, a, b)
%TERNARY Select one of two values based on a condition.
%   OUT = TERNARY(COND, A, B) returns A when COND is true and B otherwise.
%   COND follows MATLAB IF semantics: a non-scalar COND counts as true only when it
%   is non-empty and all of its elements are nonzero.
    out = b;
    if cond
        out = a;
    end
end
