% Step 1: Load data
% The script starts from a clean workspace; every variable used below comes
% from one of these three MAT-files.
close all;
clear;
load('MMF_transferFunctions_OM4_400m.mat');
load('DataBER_106.25G_OM4.mat');
load('DataLmax_106.25G_OM4.mat');

% Step 2: Extract relevant variables
% Lambdas and EMB_OM4 keep the names they are loaded with, so they need no
% alias here.
Lasers = Laser;
frequencies = freq;
Transfer_fct_MMF_OM4 = Hf_MMF_OM4;
Group_delay_OM4 = taug_OM4;

% Hf_MMF_OM4 is indexed as (wavelength, laser, frequency, fiber) below.
numWavelengths = size(Hf_MMF_OM4, 1);
numLasers = size(Hf_MMF_OM4, 2);
numFibers = size(Hf_MMF_OM4, 4);

% Combination used as the normalization reference for the dB values
laser_index = 1;
wavelength_index = 1;
fiber_index = 1;

% Define lengths
L = 30;        % Length the transfer functions are rescaled to (the "_30" results)
Lref = 400;    % Reference distance in meters
K = Lref / L;  % Frequency-axis stretch factor between the two lengths

% Reference transfer function at Lref for the chosen laser, wavelength and fiber
reference_TF = abs(squeeze(Hf_MMF_OM4(wavelength_index, laser_index, :, fiber_index))); % Absolute value for magnitude
reference_peak = max(reference_TF);
reference_TF_dB = 10*log10(reference_TF/reference_peak); % Convert to dB normalized to max

% Preallocate the arrays that the loop below actually fills. Sizing them up
% front avoids growing new_Transfer_fct_30 one slice at a time.
new_Transfer_fct_30 = zeros(numWavelengths, numLasers, numel(frequencies), numFibers);
New_HF_30dB = zeros(numWavelengths, numLasers, numel(frequencies), numFibers);

% The stretched frequency axis is the same for every combination, so it is
% computed once outside the loops.
freq_scaled = frequencies * K;

% Loop over all wavelengths, lasers, and fibers
for lambda_idx = 1:numWavelengths
    for laser_idx = 1:numLasers
        for fiber_idx = 1:numFibers
            % Magnitude of the original transfer function for this combination
            TF = abs(squeeze(Transfer_fct_MMF_OM4(lambda_idx, laser_idx, :, fiber_idx)));

            % Resample the stretched response back onto the original
            % frequency grid. Query points outside the stretched axis are
            % linearly extrapolated.
            TF_scaled_30 = interp1(freq_scaled, TF, frequencies, 'linear', 'extrap');

            % dB value relative to the peak of the single reference
            % combination (not to this combination's own peak)
            TF_scaled_30_dB = 10 * log10(TF_scaled_30 / reference_peak);

            % Store the scaled transfer function
            new_Transfer_fct_30(lambda_idx, laser_idx, :, fiber_idx) = TF_scaled_30;
            New_HF_30dB(lambda_idx, laser_idx, :, fiber_idx) = TF_scaled_30_dB;
        end
    end
end

% Step 3: Calculate Bandwidths
% One value per (wavelength, laser, fiber) combination for each bandwidth
% definition.
equivalent_bandwidths_all = zeros(numWavelengths, numLasers, numFibers);
bandwidths_3dB_all = zeros(numWavelengths, numLasers, numFibers);
bandwidths_5dB_all = zeros(numWavelengths, numLasers, numFibers);
bandwidths_10dB_all = zeros(numWavelengths, numLasers, numFibers);

% Iterate over all wavelengths, lasers, and fibers
for wavelength_index = 1:numWavelengths
    for laser_index = 1:numLasers
        for fiber_index = 1:numFibers
            % Rescaled transfer function for this specific combination
            TF_selected = squeeze(new_Transfer_fct_30(wavelength_index, laser_index, :, fiber_index));

            % Bandwidth features for the current combination
            equivalent_bandwidths_all(wavelength_index, laser_index, fiber_index) = calculateEquivalentBandwidth(TF_selected, freq);
            bandwidths_3dB_all(wavelength_index, laser_index, fiber_index) = calculateMinus3dBBandwidth(TF_selected, freq);
            bandwidths_5dB_all(wavelength_index, laser_index, fiber_index) = calculateMinus5dBBandwidth(TF_selected, freq);
            bandwidths_10dB_all(wavelength_index, laser_index, fiber_index) = calculateMinus10dBBandwidth(TF_selected, freq);
        end
    end
end

% Step 4: Reshape data into vectors
vector_equiv = reshape(equivalent_bandwidths_all, [], 1);
vector_3dB = reshape(bandwidths_3dB_all, [], 1);
vector_5dB = reshape(bandwidths_5dB_all, [], 1);
vector_10dB = reshape(bandwidths_10dB_all, [], 1);

% Target values: the first entry along the third dimension of Lmax_OM4.
% NOTE(review): presumably the remaining dimensions are (wavelength, laser,
% fiber) in the same order as the bandwidth arrays - confirm against the
% data file.
Lmax_all = squeeze(Lmax_OM4(:, :, 1, :));
Lmax_all_vector = reshape(Lmax_all, [], 1);

% A logical mask that is shorter than the vector it indexes is accepted
% silently, which would pair features with the wrong targets. Fail loudly
% instead.
assert(numel(Lmax_all_vector) == numel(vector_equiv), ...
    'Lmax_OM4 slice has %d entries but %d bandwidth values were computed.', ...
    numel(Lmax_all_vector), numel(vector_equiv));

% Step 5: Clean data (remove NaNs and Infs)
% A sample is kept only when its target AND all four features are finite.
% The bandwidth helpers can return NaN, so checking the target alone is not
% enough.
features_all = [vector_3dB, vector_5dB, vector_10dB, vector_equiv];
validIndex = isfinite(Lmax_all_vector) & all(isfinite(features_all), 2);

Lmax_all_vector_cleaned = Lmax_all_vector(validIndex);
vector_3dB_cleaned = vector_3dB(validIndex);
vector_5dB_cleaned = vector_5dB(validIndex);
vector_10dB_cleaned = vector_10dB(validIndex);
vector_equiv_cleaned = vector_equiv(validIndex);

% Combine the cleaned vectors into a feature matrix
X = [vector_3dB_cleaned, vector_5dB_cleaned, vector_10dB_cleaned, vector_equiv_cleaned];
Y = Lmax_all_vector_cleaned;

% Step 6: Augmentation settings
% Augmentation and normalization are applied per fold inside the loop below.
% Doing either on the full data set before splitting leaks information:
% a noisy copy of a test sample would sit in the training set, and the
% min/max used for scaling would include the test samples.
noise_factor = 0.05;  % Noise standard deviation, as a fraction of each feature's training range
rng('default');       % Fixed seed so the partition, noise and results are repeatable

% Step 7: Cross-Validation Setup
k = 5; % 5-Fold Cross-Validation
cv = cvpartition(size(X, 1), 'KFold', k);

% Initialize arrays to store cross-validation results
mseValues = zeros(k, 1);
maxDifferences = zeros(k, 1);
r2Scores = zeros(k, 1);

for i = 1:k
    % Get the training and testing indices for this fold
    trainIdx = training(cv, i);
    testIdx = test(cv, i);

    XTrain = X(trainIdx, :);
    YTrain = Y(trainIdx, :);
    XTest = X(testIdx, :);
    YTest = Y(testIdx, :);

    % Min-max normalization using training-fold statistics only. The same
    % offset and scale are then applied to the test fold.
    featureMin = min(XTrain, [], 1);
    featureRange = max(XTrain, [], 1) - featureMin;
    featureRange(featureRange == 0) = 1;  % Constant column: avoid dividing by zero
    XTrain = (XTrain - featureMin) ./ featureRange;
    XTest = (XTest - featureMin) ./ featureRange;

    % Data augmentation on the training fold only. The noise is added after
    % normalization so its size does not depend on the feature units.
    XTrainNoisy = XTrain + noise_factor * randn(size(XTrain));
    XTrain = [XTrain; XTrainNoisy];
    YTrain = [YTrain; YTrain]; % Duplicate targets for the augmented rows

    % Step 8: Define the neural network with tuned hyperparameters
    net = fitnet([20 15 10]);

    % Configure the training function (Levenberg-Marquardt backpropagation).
    % No learning rate is set: trainlm does not list 'lr' among its
    % training parameters (its step size is governed by mu).
    net.trainFcn = 'trainlm';

    % Adjust the training parameters
    net.trainParam.epochs = 1500;  % Increase number of epochs
    net.performParam.regularization = 0.2;  % Stronger regularization

    % Train the neural network (samples are columns for the toolbox)
    [net, tr] = train(net, XTrain', YTrain');

    % Step 9: Evaluate the model on the test set
    YPred = net(XTest');

    % Calculate accuracy metrics.
    % The MSE is computed directly from the residuals. Calling
    % mse(net, ...) on a network with a non-zero regularization setting
    % can mix the weight penalty into the reported value.
    residuals = YTest' - YPred;
    mseError = mean(residuals .^ 2);
    SS_res = sum(residuals .^ 2);
    SS_tot = sum((YTest' - mean(YTest')).^2);
    R2 = 1 - (SS_res / SS_tot);
    differences = abs(residuals);
    maxDifference = max(differences);

    % Store results for this fold
    mseValues(i) = mseError;
    maxDifferences(i) = maxDifference;
    r2Scores(i) = R2;

    % Display results for this fold
    fprintf('Fold %d:\n', i);
    fprintf('Mean Squared Error: %f\n', mseError);
    fprintf('R² Score: %f\n', R2);
    fprintf('Maximum Difference: %f\n', maxDifference);
end

% Step 10: Cross-Validation Summary
fprintf('\nCross-Validation Results:\n');
fprintf('Mean MSE: %f\n', mean(mseValues));
fprintf('Mean Maximum Difference: %f\n', mean(maxDifferences));
fprintf('Mean R² Score: %f\n', mean(r2Scores));

% Function to calculate the -3dB bandwidth
%   TF   - transfer function samples (only the magnitude is used)
%   freq - frequency value for each sample of TF
% Returns the span between the first and last frequency whose level is
% within 3 dB of the peak, or NaN if no sample qualifies.
function bandwidth_3dB = calculateMinus3dBBandwidth(TF, freq)
    bandwidth_3dB = bandwidthWithinDropFromPeak(TF, freq, 3);
end

% Function to calculate the -5dB bandwidth
%   TF   - transfer function samples (only the magnitude is used)
%   freq - frequency value for each sample of TF
% Returns the span between the first and last frequency whose level is
% within 5 dB of the peak, or NaN if no sample qualifies.
function bandwidth_5dB = calculateMinus5dBBandwidth(TF, freq)
    bandwidth_5dB = bandwidthWithinDropFromPeak(TF, freq, 5);
end

% Function to calculate the -10dB bandwidth
%   TF   - transfer function samples (only the magnitude is used)
%   freq - frequency value for each sample of TF
% Returns the span between the first and last frequency whose level is
% within 10 dB of the peak, or NaN if no sample qualifies.
function bandwidth_10dB = calculateMinus10dBBandwidth(TF, freq)
    bandwidth_10dB = bandwidthWithinDropFromPeak(TF, freq, 10);
end

% Shared implementation for the fixed-threshold bandwidth functions above.
%   drop_dB - allowed drop below the peak level, in dB (positive number)
% The level is taken as 10*log10(abs(TF)). The result is the distance in
% freq between the first and the last sample at or above (peak - drop_dB),
% so dips below the threshold between those two samples do not shorten it.
function bandwidth = bandwidthWithinDropFromPeak(TF, freq, drop_dB)
    % Convert the transfer function to dB
    TF_dB = 10 * log10(abs(TF));

    % Samples whose level is within drop_dB of the maximum level
    within_indices = find(TF_dB >= max(TF_dB) - drop_dB);

    if isempty(within_indices)
        bandwidth = NaN; % No sample qualifies (e.g. empty or all-NaN input)
    else
        bandwidth = freq(within_indices(end)) - freq(within_indices(1));
    end
end

% Equivalent bandwidth of a transfer function.
%   TF   - transfer function samples (only the magnitude is used)
%   freq - frequency value for each sample of TF
% Returns the area under abs(TF).^2 over freq (trapezoidal rule) divided by
% the largest value of abs(TF).^2.
function equivalent_bandwidth = calculateEquivalentBandwidth(TF, freq)
    squared_magnitude = abs(TF) .^ 2;

    % Peak of the squared magnitude, used as the normalizing height
    peak_value = max(squared_magnitude);

    % Area under the squared-magnitude curve
    area_under_curve = trapz(freq, squared_magnitude);

    equivalent_bandwidth = area_under_curve / peak_value;
end
