

% the second case
% NOTE(review): "the second case" is not explained in this file; presumably
% it names an experiment variant — TODO document what it refers to.

% Start from a clean session: empty workspace, no open figures, clear console.
clear; close all; clc;
% Put the global random number generator back in its default state so any
% random draws made later in the session are reproducible.
rng("default")
%% load losses
% Four loss sets (suffixes 0..3), each with a training and a test vector,
% read from NumPy .npy files in the current folder.
% NOTE(review): readNPY is not a MATLAB built-in; presumably it comes from the
% npy-matlab package, which must be on the path — confirm.
% NOTE(review): the sets are presumably one per autoencoder (see the
% "ranking AEs" section name) — confirm.

% training-set losses
train_loss_0 = readNPY("train_losses_0.npy");
train_loss_1 = readNPY("train_losses_1.npy");
train_loss_2 = readNPY("train_losses_2.npy");
train_loss_3 = readNPY("train_losses_3.npy");

% test-set losses
test_loss_0 = readNPY("test_losses_0.npy");
test_loss_1 = readNPY("test_losses_1.npy");
test_loss_2 = readNPY("test_losses_2.npy");
test_loss_3 = readNPY("test_losses_3.npy");

%% ecdf
% Empirical CDF of every training-loss vector.
% With output arguments, ecdf returns the CDF values (first output) and the
% points they are evaluated at (second output) without drawing anything.
% Called with no outputs, it plots the ECDF into the current figure.

[train_loss_0_ecdf, train_loss_0_] = ecdf(train_loss_0);
[train_loss_1_ecdf, train_loss_1_] = ecdf(train_loss_1);
[train_loss_2_ecdf, train_loss_2_] = ecdf(train_loss_2);
[train_loss_3_ecdf, train_loss_3_] = ecdf(train_loss_3);

% One named figure per training set, each showing its ECDF.
figure('Name','train_loss_0'), ecdf(train_loss_0)
figure('Name','train_loss_1'), ecdf(train_loss_1)
figure('Name','train_loss_2'), ecdf(train_loss_2)
figure('Name','train_loss_3'), ecdf(train_loss_3)

%% ecdf of the test losses (evaluation points 1:ecdf_support_len)
% For each loss set k, pass the training ECDF (points train_loss_k_, values
% train_loss_k_ecdf) and the test losses to `fun`, a helper defined outside
% this file. The result test_loss_k_ecdf is consumed by the "%% max" section.
% NOTE(review): an earlier header for this section read "(1:390)" while the
% code used 395 — confirm which length is intended.
% NOTE(review): only the evaluation points are truncated; the CDF values are
% passed in full. If `fun` needs equal lengths, this relies on every training
% ECDF having exactly ecdf_support_len points — verify against `fun`.
ecdf_support_len = 395;
test_loss_0_ecdf = fun(train_loss_0_(1:ecdf_support_len), train_loss_0_ecdf, test_loss_0);
test_loss_1_ecdf = fun(train_loss_1_(1:ecdf_support_len), train_loss_1_ecdf, test_loss_1);
test_loss_2_ecdf = fun(train_loss_2_(1:ecdf_support_len), train_loss_2_ecdf, test_loss_2);
test_loss_3_ecdf = fun(train_loss_3_(1:ecdf_support_len), train_loss_3_ecdf, test_loss_3);

%% max
% Stack the four test ECDF vectors (one row per loss set) and keep, for each
% test sample, the largest value across the sets.
% The stacked matrix is deliberately NOT named "ecdf". That name would shadow
% the ecdf function, so re-running the "%% ecdf" section in the same
% workspace would index this variable instead of calling the function.
% Each vector is reshaped to a row so that max always reduces across the
% loss sets (dimension 1), whatever orientation `fun` returned.
ecdf_all = [test_loss_0_ecdf(:).'; ...
            test_loss_1_ecdf(:).'; ...
            test_loss_2_ecdf(:).'; ...
            test_loss_3_ecdf(:).'];
ecdf_max = max(ecdf_all, [], 1);

save("loss_ecdf_max.mat","ecdf_max")

% Overlaid score histograms for samples 1:51 and samples 52:end.
% NOTE(review): later sections treat samples 1:51 as normal and 52:end as
% anomalous; this split is hard-coded rather than read from the data.
figure('Name','test_ecdf_max'),hold on
histogram(ecdf_max(1:51),'NumBins',100)
histogram(ecdf_max(52:end),'NumBins',100)

%% evaluate
close all
% Split the per-sample max-ECDF scores: samples 1..51 are treated as normal
% and samples 52..end as anomalous (a hard-coded split).
normal_max = ecdf_max(1:51);
anomaly_max = ecdf_max(52:end);
% `codes` is the score vector that the ranking sections read.
codes = ecdf_max;


%% ranking AEs
close all

% Scores 1..n_normal are normal; the rest are anomalous.
n_normal = 51;
% Number of anomalies paired with the normal block in each evaluation row.
group_size = 20;

normal = codes(1:n_normal);
anomaly = codes(n_normal+1:end);

% One row per full group of anomalies: [all normal scores, one disjoint
% group of anomaly scores]. Anomalies left over after the last full group
% are not used. The group count is derived from the data instead of the
% previous hard-coded floor(499/20).
n_groups = floor(numel(anomaly) / group_size);
test_codes = zeros(n_groups, n_normal + group_size, 'like', codes);
for i = 1:n_groups
    group = anomaly((i-1)*group_size+1 : i*group_size);
    test_codes(i,:) = [normal(:).', group(:).'];
end

% ROC AUC for each row, with anomalies (label 1) as the positive class.
labels = [zeros(1,n_normal), ones(1,group_size)];
auc_aes = zeros(1, n_groups);
for i = 1:n_groups
    [~, ~, ~, auc_aes(i)] = perfcurve(labels, test_codes(i,:), 1);
end
auc_AEs_mean = mean(auc_aes)


% Rank-based evaluation of every row of test_codes. Columns 1..n_normal hold
% normal scores and the remaining n_top columns hold anomaly scores.
% Within each row:
%   * the n_top highest scores are checked for how many are anomalies (a)
%     versus normals (a_f);
%   * the n_normal lowest scores are checked for how many are normals (n)
%     versus anomalies (n_f).
% Ties: sort is stable, so among equal scores the column with the larger
% index ranks higher (i.e. a tied anomaly outranks a tied normal).
% NOTE(review): these rate names treat the anomalies as the "negative" class
% (TNR_SAE = share of the top n_top that are anomalies). This is the opposite
% of the perfcurve call, which uses label 1 = anomaly as the positive class.
n_normal = 51;
[n_rows, n_cols] = size(test_codes);
n_top = n_cols - n_normal;

% A NaN score has no rank; fail loudly instead of with an indexing error.
if any(isnan(test_codes(:)))
    error('ranking:nanScore', ...
        'test_codes contains NaN scores; the rank-based metrics are undefined.')
end

% order(i,k) = column index of the k-th smallest score in row i.
[sorted_codes, order] = sort(test_codes, 2);
is_anomaly = order > n_normal;                       % n_rows x n_cols

a   = sum(is_anomaly(:, n_normal+1:end), 2).';       % anomalies in the top n_top
a_f = n_top - a;                                     % normals in the top n_top
n   = sum(~is_anomaly(:, 1:n_normal), 2).';          % normals in the bottom n_normal
n_f = n_normal - n;                                  % anomalies in the bottom n_normal

TNR_SAE = a / n_top;
FPR_SAE = a_f / n_top;
TPR_SAE = n / n_normal;
FNR_SAE = n_f / n_normal;

% Averages over the n_rows evaluation rows. The "tolal_" spelling is kept so
% the workspace variable names seen by existing readers do not change.
tolal_TNR_SAE = mean(TNR_SAE)
tolal_FNR_SAE = mean(FNR_SAE)
tolal_TPR_SAE = mean(TPR_SAE)
tolal_FPR_SAE = mean(FPR_SAE)

% Fraction of correctly placed samples per row, averaged over rows.
acc = (a + n) / n_cols;
acc_SAE = mean(acc)

%% threshold evaluation (disabled by default)
% Alternative evaluation on ecdf_max. A sample is flagged as anomalous when
% its max-ECDF score is exactly 1; samples 1..n_normal are the normals and
% the rest are the anomalies. The section runs only when run_threshold_eval
% is true, so by default the script stops producing results after the
% ranking metrics.
run_threshold_eval = false;
if run_threshold_eval
    n_normal = 51;
    normal_scores  = ecdf_max(1:n_normal);
    anomaly_scores = ecdf_max(n_normal+1:end);
    n_anomaly = numel(anomaly_scores);

    % Separate counter names so the ranking section's a / n are not clobbered.
    thr_anom_hit  = sum(anomaly_scores == 1);        % anomalies flagged
    thr_anom_miss = n_anomaly - thr_anom_hit;        % anomalies not flagged
    TNR = thr_anom_hit / n_anomaly;
    FNR = thr_anom_miss / n_anomaly;

    thr_norm_hit  = sum(normal_scores < 1);          % normals not flagged
    thr_norm_miss = n_normal - thr_norm_hit;         % normals flagged (or NaN)
    TPR = thr_norm_hit / n_normal;
    FPR = thr_norm_miss / n_normal;

    acc_max = (TNR + TPR) / 2;
end

