
#---- NOTE ----
# Nella simulazione si adotta un design event-driven, cioè 
# la durata dello studio D è generata come funzione di 
# lambda_v e lambda_c attraverso una regola di arresto a n_target eventi.
# Tuttavia, nell'inferenza bayesiana si assume che D sia osservato e fisso,
# coerentemente con il modello gerarchico avente verosimiglianza condizionata a D.
# Il design event-driven adatta la durata al livello di rischio: l'inferenza  
# risulta molto più robusta se la malattia è meno frequente del previsto e lo studio
# clinico è meno costoso se la malattia è più frequente del previsto.

#---- CODICE ----
rm(list = ls())  # pulisce l'ambiente
gc()             # libera memoria

# ---- FUNZIONE ACCEPT-REJECT PER IL CAMPIONAMENTO PARABOLICO ----
sample_accept_reject <- function(N, a, b, B) {
  
  # controllo parametri
  if (!is.numeric(N) || length(N) != 1 || N <= 0) stop("N deve essere un intero positivo.")
  if (!isTRUE(all.equal(N, as.integer(N)))) stop("N deve essere un intero.")
  if (!is.numeric(B) || length(B) != 1 || B <= 0) stop("B deve essere positivo.")
  if (!is.numeric(a) || length(a) != 1 || a < 0 || a > 2) stop("a deve essere nel range [0, 2].")
  bmin <- -a - 1
  bmax <- 1 - 2*a
  if (!is.numeric(b) || length(b) != 1 || b < bmin || b > bmax) {
    stop(sprintf("b deve essere nel range [%g, %g] per avere una densità concava.", bmin, bmax))
  }
  
  # campionamento uniforme
  if (a == 0 && b==0) {return(runif(N, 0, B))}
  # coefficienti della densità parabolica concava su [0,B]
  A1 <- -3 * a / B^3
  A2 <- 2 * (3 * a + b) / B^2
  A3 <- (1 - 2 * a - b) / B
  times <- numeric(N)
  
  # campionamento lineare
  if (a == 0) {
    f_max <- max(A2*B + A3, A3)
    i <- 1
    while (i <= N) {
      x <- runif(1, 0, B)
      y <- runif(1, 0, f_max)
      fx <- A2 * x + A3
      if (y <= fx) {
        times[i] <- x
        i <- i + 1
      }
    } 
    
    # campionamento parabolico    
  } else {
    r_star <- (3 * a + b) / (3 * a) * B # vertice della parabola
    f_max <- max(A1 * r_star^2 + A2 * r_star + A3, A1 * B^2 + A2 * B + A3, A3)
    i <- 1
    while (i <= N) {
      x <- runif(1, 0, B)
      y <- runif(1, 0, f_max)
      fx <- A1 * x^2 + A2 * x + A3
      if (y <= fx) {
        times[i] <- x
        i <- i + 1
      }
    }  
  }
  return(times)
} 

# ---- SELEZIONE DELLA DISTRIBUZIONE DI RECLUTAMENTO ----
# dist_flag: "uniform" | "parabolic" | "beta"
sample_recruitment_times <- function(N, dist_flag = "uniform", B,
                                     a = 0, b = 0, shape1 = 2, shape2 = 5) {
  dist_flag <- tolower(as.character(dist_flag))
  dist_flag <- match.arg(dist_flag, c("uniform", "parabolic", "beta"))
  if (dist_flag == "uniform") {
    return(runif(N, 0, B))
  } else if (dist_flag == "parabolic") {
    return(sample_accept_reject(N, a, b, B))
  } else { # "beta"
    return(rbeta(N, shape1 = shape1, shape2 = shape2) * B)
  }
}

# ---- FUNZIONE DI SIMULAZIONE ----
simula_trial <- function(N, n_target, a, b, D, B, lambda_c, VE,
                         dist_flag = "uniform", shape1 = 2, shape2 = 2) {
  
  lambda_v <- (1 - VE) * lambda_c
  group <- rep(c("v", "c"), each = N/2) #sample(c("v", "c"), N, replace = TRUE) 
  rates <- ifelse(group == "v", lambda_v, lambda_c)
  rec_times <- sample_recruitment_times(N, dist_flag = dist_flag, B = B,
                                        a = a, b = b, shape1 = shape1, shape2 = shape2) 
  TTE <- rexp(N, rates) 
  
  # ---- DETERMINAZIONE DURATA DELLO STUDIO ----
  # D <- sort(rec_times + TTE)[n_target]             # tempo dell'n-esimo evento
  # if (D<B){
  #   B <- D                                         # il reclutamento si arresta all'n-esimo evento
  # }
  # ---- TEMPI DI SORVEGLIANZA ----
  surv_times <- pmin(D-rec_times,TTE)               # tempi di sorveglianza individuali
  in_study <- rec_times <= D                        # chi è entrato prima della fine
  
  # ---- DIMENSIONE CAMPIONARIA E DURATA RECLUTAMENTO OSSERVATI ----
  N_eff <- sum(in_study)                            # numero di soggetti nello studio
  B_eff <- max(rec_times[in_study])                 # durata effettiva del reclutamento
  #B <- B_eff
  # ---- OUTPUT AGGREGATO PER GRUPPO ----
  s_v <- sum(surv_times[group == "v" & in_study])   # sorveglianza totale vaccino
  s_c <- sum(surv_times[group == "c" & in_study])   # sorveglianza totale placebo
  x_v <- sum(group == "v" & (rec_times + TTE) <= D) # numero di casi vaccinati
  x_c <- sum(group == "c" & (rec_times + TTE) <= D) # numero di casi non vaccinati
  N_v <- sum(group == "v" & in_study)               # numero di vaccinati nello studio
  N_c <- sum(group == "c" & in_study)               # numero di non vaccinati nello studio
  
  return(list(D = D, N = N, N_eff = N_eff, B_eff = B_eff, B = B,
              x_total = x_v + x_c, s_v = s_v, s_c = s_c, x_v = x_v, 
              x_c = x_c, N_v = N_v, N_c = N_c))
}

# ---- MODELLO COMPLETO JAGS E INFERENZA BAYESIANA ----

run_jags_inference <- function(x_v, x_c, s_v, s_c, N_v, N_c, D,
                               VE_min = 0.3,
                               n_iter = 20000, n_burnin = 2000, n_chains = 3, thin = 20) {
  library(rjags)
  library(coda)
  
  model_string <- "
  model {
    
    # ---- GRUPPO VACCINO ----
    
    # ---- Prior su E[min(T_v,C)] ----
    Em_v ~ dunif(0,D)
    
    # ---- Prior su P(T_v < C) ----
    p_v <- (1-VE)*p_c*Em_v/Em_c 
    
    # ---- Prior su E[min(T_v,C)^2] ----
    varm_v ~ dunif(0,D^2)
    Em2_v <- Em_v^2 + varm_v
    
    #---- GRUPPO CONTROLLO ----
    
    # ---- Prior su E[min(T_c,C)] ----
    Em_c ~ dunif(0,D)
            
    # ---- Prior su P(T_c < C) ----
    p_c ~ dunif(0,1) #dbeta(0.1*Em_c/(1-0.1*Em_c),1)
    
    # ---- Prior su E[min(T_c,C)^2] ----
    varm_c ~ dunif(0,D^2)
    Em2_c <- Em_c^2 + varm_c
    
    # ---- VEROSIMIGLIANZA ----
    #(X_v, S_v) ~ normale bivariata
    mu_xv <- N_v*p_v
    mu_sv <- N_v*Em_v
    var_xv <- N_v*p_v*(1-p_v)
    var_sv <- N_v*(Em2_v - Em_v^2)
    cov_v <- N_v*p_v*(0.5*Em2_v/Em_v - Em_v) 
    mu_sv_cond <- mu_sv + cov_v/(var_xv) * (x_v - mu_xv)
    var_sv_cond <- var_sv - cov_v^2/var_xv
    x_v ~ dbin(p_v,N_v) #dnorm(mu_xv, 1/var_xv) #dpois(lambda_v*s_v)
    s_v ~ dnorm(mu_sv_cond, 1/var_sv_cond)
    
    # (X_c, S_c) ~ normale bivariata
    mu_xc <- N_c*p_c
    mu_sc <- N_c*Em_c
    var_xc <- N_c*p_c*(1-p_c)
    var_sc <- N_c*(Em2_c - Em_c^2)
    cov_c <- N_c*p_c*(0.5*Em2_c/Em_c - Em_c) 
    mu_sc_cond <- mu_sc + cov_c/(var_xc) * (x_c - mu_xc)
    var_sc_cond <- var_sc - cov_c^2/var_xc
    x_c ~ dbin(p_c,N_c) #dnorm(mu_xc, 1/var_xc) #dpois(lambda_c*s_c)
    s_c ~ dnorm(mu_sc_cond, 1/var_sc_cond)
    
    # ---- Prior Pfizer su VE ----
    theta_hat <- (1-0.3)/(2-0.3)
    a <- theta_hat/(1-theta_hat)
    theta ~ dbeta(a,1)
    VE <- 1 - theta/(1-theta)
    
    # ---- Prior esponenziale su (1 - VE) mediante CDF inversa ----
    # a <- 1/(1-0.3)
    # u ~ dunif(0,1)
    # VE <- 1 + log(1-u)/a 
    
  }
  "
  
  jags_data <- list(
    x_c = x_c,
    x_v = x_v,
    s_v = s_v,
    s_c = s_c,
    D = D,
    N_v = N_v,
    N_c = N_c
  )
  
  model <- jags.model(textConnection(model_string), data = jags_data,  
                      n.chains = n_chains, n.adapt = 5000)
  
  update(model, n.iter = n_burnin) # scarta i primi n_burnin campioni a posteriori
  
  samples <- coda.samples(model, variable.names = c("VE"), n.iter = n_iter, thin = thin)
  ve_samples <- as.matrix(samples)[, "VE"] # combina tutte le catene in un'unica matrice
  prob_VE_full <- mean(ve_samples > VE_min)
  ci_full <- quantile(ve_samples, c(0.025, 0.975))
  
  # Metodo A di Ewell
  # a_v <- 1
  # a_c <- (2-0.3)/(1-0.3)
  # b <- a_c/0.15
  # theta_hat <- (1-0.3)/(2-0.3)
  # a_v <- theta_hat/(1-theta_hat)
  # a_c <- 1
  # b_c <- a_c/0.15
  # b_v <- b_c*s_v/s_c
  # lambda_v_samples <- rgamma(12000, a_v + x_v, rate = b_v + s_v)
  # lambda_c_samples <- rgamma(12000, a_c + x_c, rate = b_c + s_c)
  # ve_samples_cond <- 1 - lambda_v_samples/lambda_c_samples
  # prob_VE_cond <- mean(ve_samples_cond > VE_min)
  # ci_cond <- quantile(ve_samples_cond, c(0.025, 0.975))
  
  # Modello binomiale condizionato
  theta_hat <- (1-0.3)/(2-0.3)
  a <- theta_hat/(1-theta_hat)
  a_post <- a + x_v
  b_post <- 1 + x_c
  theta_samples <- rbeta(3000,a + x_v, 1 + x_c)
  ve_samples_cond <- ((1-theta_samples)-theta_samples*s_c/s_v)/(1-theta_samples)
  prob_VE_cond <- mean(ve_samples_cond > VE_min)
  ci_cond <- quantile(ve_samples_cond, c(0.025, 0.975))
  
  # Metodo B di Ewell
  z <- qnorm(0.975,0,1)
  xv <- x_v + 0.5
  xc <- x_c + 0.5
  sv <- s_v + 0.5
  sc <- s_c + 0.5
  ci_ew <- c(1 - exp(log(xv*sc/(xc*sv)) + z*sqrt(1/xv + 1/xc)), 
             1 - exp(log(xv*sc/(xc*sv)) - z*sqrt(1/xv + 1/xc)))
  
  # Metodo Clopper-Pearson
  L <- qbeta(0.025, x_v, 1 + x_c)
  H <- qbeta(0.975, 1 + x_v, x_c)
  ci_cp <- c(1 - H/(1 - H)*s_c/s_v, 1 - L/(1 - L)*s_c/s_v)
  
  return(list(
    prob_VE_full = prob_VE_full,
    VE_CI_full = ci_full,
    VE_mean_full = mean(ve_samples),
    VE_median_full = median(ve_samples),
    VE_sd_full = sd(ve_samples),
    prob_VE_cond = prob_VE_cond,
    VE_CI_cond = ci_cond,
    VE_mean_cond = mean(ve_samples_cond),
    VE_median_cond = median(ve_samples_cond),
    VE_sd_cond = sd(ve_samples_cond),
    VE_CI_ew = ci_ew,
    VE_CI_cp = ci_cp,
    MLE_v = x_v/s_v,
    MLE_c = x_c/s_c,
    conv = max(gelman.diag(samples)$psrf), # Convergenza tra catene (PSRF ≈ 1)
    autcorr = autocorr.diag(samples)[2,"VE"], # Controlla autocorrelazione
    ess = floor(effectiveSize(samples)["VE"])
  ))
}

# ---- STUDI MULTIPLI PER SCENARI MULTIPLI ----

set.seed(2)
N_sim <- 10000

# Incidenza della malattia e scenari per l'efficacia vaccinale
lambda_c <- 0.1
VE_values <- c(0.1, 0.3, 0.5, 0.7, 0.9)

# Tipo di reclutamento e relativi parametri:
# Parabolico:
a_parab <- 2
b_parab <- -3  
# Beta:
beta_shape1 <- 2
beta_shape2 <- 2
dist_flags <- c("uniform", "beta") #,"parabolic")  

# Parametri di controllo
#N <- 40000
B <- 0.75
D <- 1
n_target_values <- c(40, 80, 160, 900)
# Funzione per simulare un intero scenario (VE, n_target, dist_flag)
simulate_scenario <- function(VE, n_target, dist_flag) {
  # Stima del numero di partecipanti 
  lambda_v <- (1-VE)*lambda_c
  prob <- lambda_v/(lambda_c+lambda_v) 
  n_target_v <- prob*n_target
  n_target_c <- (1-prob)*n_target
  if (dist_flag == "beta") {
    fun_c <- function(R){pexp(D-B*R,lambda_c)*dbeta(R,beta_shape1,beta_shape2)}
    fun_v <- function(R){pexp(D-B*R,lambda_v)*dbeta(R,beta_shape1,beta_shape2)}
  } else {
    fun_c <- function(R){pexp(D-B*R,lambda_c)*dunif(R,0,1)}
    fun_v <- function(R){pexp(D-B*R,lambda_v)*dunif(R,0,1)}
  }  
  p_c <- integrate(fun_c, 0, 1)$value
  p_v <- integrate(fun_v, 0, 1)$value
  N <- round((n_target*2)/(p_c + p_v))
  if (N %% 2 != 0) N <- N + 1  # rende N pari se è dispari
  trials <- replicate(N_sim, {
    # 1) genera un singolo trial 
    sim <- simula_trial(N, n_target, a = a_parab, b = b_parab, D = D, B = B,
                        lambda_c = lambda_c, VE = VE,
                        dist_flag = dist_flag,
                        shape1 = beta_shape1, shape2 = beta_shape2)
    
    # 2) inferenza bayesiana
    infer <- run_jags_inference(
      x_v = sim$x_v, x_c = sim$x_c,
      s_v = sim$s_v, s_c = sim$s_c,
      N_v = sim$N_v, N_c = sim$N_c,
      D = sim$D, VE_min = 0.3
    )
    
    # 3) post-processing
    # intervalli di credibilità del metodo full
    ci_vals_full <- infer$VE_CI_full 
    infer$VE_CI_full <- NULL
    infer$VE_CI_low_full <- ci_vals_full[1] 
    infer$VE_CI_upp_full <- ci_vals_full[2]
    
    # intervalli di credibilità del metodo binomiale condizionato
    ci_vals_cond <- infer$VE_CI_cond 
    infer$VE_CI_cond <- NULL
    infer$VE_CI_low_cond <- ci_vals_cond[1] 
    infer$VE_CI_upp_cond <- ci_vals_cond[2]
    
    # intervalli di confidenza del metodo B di Ewell
    ci_vals_ew <- infer$VE_CI_ew 
    infer$VE_CI_ew <- NULL
    infer$VE_CI_low_ew <- ci_vals_ew[1] 
    infer$VE_CI_upp_ew <- ci_vals_ew[2]
    
    # intervalli di confidenza di Clopper-Pearson
    ci_vals_cp <- infer$VE_CI_cp 
    infer$VE_CI_cp <- NULL
    infer$VE_CI_low_cp <- ci_vals_cp[1] 
    infer$VE_CI_upp_cp <- ci_vals_cp[2]
    
    as.data.frame(c(sim, infer))
  }, simplify = FALSE)
  
  trials_df <- do.call(rbind, lapply(trials, as.data.frame))
  trials_df$VE <- VE
  trials_df$n_target <- n_target
  trials_df$dist_flag <- dist_flag
  return(trials_df)
}

# costruisci la griglia di scenari e lanciali tutti
grid <- expand.grid(VE = VE_values, n_target = n_target_values, dist_flag = dist_flags, KEEP.OUT.ATTRS = FALSE, 
                    stringsAsFactors = FALSE)

start_time <- proc.time()

all_results <- lapply(seq_len(nrow(grid)), function(i) {
  simulate_scenario(grid$VE[i], grid$n_target[i], grid$dist_flag[i])
})

combined_df <- do.call(rbind, all_results)
save(combined_df, file = "new_full-likelihood.RData")

end_time <- proc.time()
total_time <- round((end_time - start_time)["elapsed"], 2)

# ---- STATISTICHE DESCRITTIVE PER SCENARIO (VE, n_target) ----
library(dplyr)

summary_stats <- combined_df %>%
  select(VE, n_target, dist_flag, D, N_v, N_c, x_total, x_v, x_c, s_v, s_c) %>%
  group_by(VE, n_target, dist_flag) %>%
  summarise(across(where(is.numeric), mean, .names = "mean_{.col}"), .groups = "drop")

# ---- Soglie ottimali e metriche PER ciascun n_target ----
# th_* calcolate sulla distribuzione di prob_VE_* quando VE=0.3, separatamente per n_target
th_by_nt <- combined_df %>%
  filter(VE == 0.3) %>%
  group_by(n_target, dist_flag) %>%
  summarise(
    th_full_opt = quantile(prob_VE_full, 0.975),
    th_cond_opt  = quantile(prob_VE_cond,  0.975),
    .groups = "drop"
  )

decision_summary <- combined_df %>%
  left_join(th_by_nt, by = c("n_target", "dist_flag")) %>%
  group_by(VE, n_target, dist_flag) %>%
  summarise(
    N_sim = n(),
    #CI_low_ew = mean(VE_CI_low_ew),
    #CI_upp_ew = mean(VE_CI_upp_ew),
    #CI_low_cp = mean(VE_CI_low_cp),
    #CI_upp_cp = mean(VE_CI_upp_cp),
    power_full = mean(prob_VE_full > first(th_full_opt)),
    coverage_full = mean(VE_CI_low_full < VE & VE < VE_CI_upp_full),
    power_cond = mean(prob_VE_cond > first(th_cond_opt)),
    coverage_cond = mean(VE_CI_low_cond < VE & VE < VE_CI_upp_cond),
    length_reduct = 1 - mean((VE_CI_upp_full - VE_CI_low_full)/(VE_CI_upp_cond - VE_CI_low_cond)),
    mse_mean_reduct = 1 - mean((VE_mean_full - VE)^2)/mean((VE_mean_cond - VE)^2),
    #var_mean_reduct = 1 - var(VE_mean_full)/var(VE_mean_cond),
    #bias2_mean_reduct = 1 - (mean(VE_mean_full - VE))^2/(mean(VE_mean_cond - VE))^2,
    PSFR = mean(conv),
    corr = mean(autcorr),
    ess = mean(ess),
    .groups = "drop"
  )

cat("\nTempo totale simulazione:", total_time/60, "minuti\n")

cat("===== STATISTICHE DESCRITTIVE =====\n")
print(
  summary_stats %>%
    dplyr::mutate(dplyr::across(where(is.numeric), ~ round(.x, 3)))
)


cat("\n===== POTENZA/ERRORE I TIPO & DIAGNOSTICHE (per VE, n_target) =====\n")
print(
  decision_summary %>%
    dplyr::mutate(dplyr::across(where(is.numeric), ~ round(.x, 3)))
)

