

% Start from a clean session: remove workspace variables, close all figures,
% clear the Command Window, and reset the random number generator to its
% default settings so random initialisations later in the script repeat.
clear; close all; clc;
rng('default');

%% Load SSA latent codes and split the test set
% Codes are read from .npy files with readNPY (npy-matlab). Rows are
% treated as samples (fitgmdist below consumes them that way).
train_codes = readNPY('train_codes_SSA.npy');
val_codes = readNPY('valid_codes_SSA.npy');   % not used elsewhere in this script
test_codes = readNPY('test_codes_SSA.npy');

% The test set stores the normal samples first, followed by the anomalies.
n_normal = 526;   % number of normal samples at the top of test_codes
assert(size(test_codes, 1) >= n_normal, ...
    'test_codes has %d rows; expected at least %d normal samples.', ...
    size(test_codes, 1), n_normal);
normal_test = test_codes(1:n_normal,:);
anomaly_test = test_codes(n_normal+1:end,:);

% Disabled experiment: batches of the full normal set plus 20 anomalies each.
% for i = 1:floor(2081/20)
%     codes(i,:,:) = [normal_test; anomaly_test((i-1)*20+1:i*20,:)];
% 
% end

% codes = [normal_test; anomaly_test(1:20,:)];

%% GMM: fit on the training codes and score the training set
% Fit a 4-component Gaussian mixture to the training codes. A sample's
% anomaly score is its negative log-density under the fitted mixture.
gmm = fitgmdist(train_codes, 4);
% probs = pdf(gmm, test_codes(1:155,:));
probs = pdf(gmm, train_codes);
train_scores = -log(probs);

% Distribution of training scores (single histogram, so no hold needed).
figure("name","train_GMM")
histogram(train_scores);

save("SAE_GMM_train.mat","train_scores");


% probs = pdf(gmm, codes);
% Score the test set with the same negative log-density used for training.
probs = pdf(gmm, test_codes);
test_scores = -log(probs);

% pdf() can evaluate to exactly 0 in floating point for points far from
% every mixture component, which makes -log(0) = Inf. Clamp those scores to
% the largest finite test score so they stay at the anomalous end of the
% scale while max(), histogram() and downstream consumers see finite values.
% A logical mask is used rather than a placeholder value: writing a char
% such as 'a' into this double array would store its code (97) and collide
% with real scores.
is_inf = (test_scores == Inf);
if any(is_inf)
    finite_scores = test_scores(~is_inf);
    if isempty(finite_scores)
        % Nothing finite to clamp to; assigning max([]) (empty) would
        % delete elements, so leave the scores as they are.
        warning('All test scores are Inf; leaving them unchanged.');
    else
        test_scores(is_inf) = max(finite_scores);
    end
end

% The first size(normal_test,1) rows of test_codes are the normal samples,
% and the rest are anomalies (same split that produced normal_test).
n_normal = size(normal_test, 1);
figure("name","test_GMM"), hold on
h_normal = histogram(test_scores(1:n_normal,:),NumBins=50);
h_anomaly = histogram(test_scores(n_normal+1:end,:),NumBins=50);
legend("normal", "anomaly")
hold off

save("SAE_GMM_test.mat","test_scores");

