Skip to contents

Step 2

# Run binaryRL's recovery routine on the Mason_2024_G2 data set for the
# TD, RSTD and Utility models. Argument meanings below are read off their
# names; see the binaryRL documentation for the authoritative description.
recovery <- binaryRL::rcv_d(
  # ---- input data and policies -------------------------------------------
  data = binaryRL::Mason_2024_G2,
  policy = c("off", "on"),

  # ---- models used to generate synthetic data ------------------------------
  model_names = c("TD", "RSTD", "Utility"),
  simulate_models = list(binaryRL::TD, binaryRL::RSTD, binaryRL::Utility),
  simulate_lower = list(c(0, 0), c(0, 0, 0), c(0, 0, 0)),
  simulate_upper = list(c(1, 1), c(1, 1, 1), c(1, 1, 1)),

  # ---- models fitted to the synthetic data ---------------------------------
  # Same lower bounds as for simulation; the last parameter of each model
  # gets an upper bound of 10 instead of 1.
  fit_models = list(binaryRL::TD, binaryRL::RSTD, binaryRL::Utility),
  fit_lower = list(c(0, 0), c(0, 0, 0), c(0, 0, 0)),
  fit_upper = list(c(1, 10), c(1, 1, 10), c(1, 1, 10)),

  # ---- run size, reproducibility, parallelism ------------------------------
  # NOTE(review): iteration_s / iteration_f presumably count simulation and
  # fitting iterations, and nc the number of cores -- confirm in binaryRL.
  iteration_s = 100,
  iteration_f = 100,
  seed = 123,
  nc = 16,

  # ---- optimiser -----------------------------------------------------------
  # Alternatives kept for reference (swap in for the line below):
  #   algorithm = "L-BFGS-B"   Gradient-Based (stats::optim)
  #   algorithm = "GenSA"      Simulated Annealing (GenSA)
  #   algorithm = "GA"         Genetic Algorithm (GA)
  #   algorithm = "DEoptim"    Differential Evolution (DEoptim)
  #   algorithm = "PSO"        Particle Swarm Optimization (pso)
  #   algorithm = "Bayesian"   Bayesian Optimization (mlrMBO)
  #   algorithm = "CMA-ES"     Covariance Matrix Adapting (cmaes)
  # Active choice: Nonlinear Optimization (nloptr)
  algorithm = c("NLOPT_GN_MLSL", "NLOPT_LN_BOBYQA")
)

# Stack the elements of `recovery` into a single table and move the
# identifying columns (simulated model, fitted model, iteration) to the front.
result <- dplyr::select(
  dplyr::bind_rows(recovery),
  simulate_model, fit_model, iteration, everything()
)

# Save the combined table as CSV, without row names.
write.csv(
  x = result,
  file = "../OUTPUT/result_recovery.csv",
  row.names = FALSE
)

Parameter Recovery

Plot Function
#' Plot simulated against recovered values of one model parameter
#'
#' Keeps the rows of `data` in which the fitted model is the generating model
#' and both equal `model`, then scatters `output_param_<param_index>` against
#' `input_param_<param_index>`. A synthetic green reference cloud (with its
#' linear fit) is drawn underneath and also determines the axis limits. The
#' figure is written to `filename` and returned.
#'
#' When `param_name` is exactly "τ" both axes are shown on a log10 scale: the
#' values are log10-transformed (after subtracting 1 if every simulated value
#' exceeds 1, in which case the label becomes "τ-1") and tick labels are
#' rendered as powers of ten.
#'
#' Side effect: calls `set.seed(123)`, which resets the session's random
#' number generator.
#'
#' @param data Data frame holding `simulate_model`, `fit_model`,
#'   `input_param_<i>` and `output_param_<i>` columns. Defaults to the saved
#'   recovery results.
#' @param model Name of the model to plot (matched against `simulate_model`).
#' @param param_index Index `i` selecting the `input_param_<i>` /
#'   `output_param_<i>` column pair.
#' @param param_name Parameter label used in the axis titles.
#' @param rate Rate of the exponential draws behind the τ reference cloud.
#'   Only used when `param_name` is "τ".
#' @param filename Output path handed to `ggplot2::ggsave()`.
#' @return The ggplot object that was saved.
plot_param_rcv <- function(
  data = read.csv("../OUTPUT/result_recovery.csv"),
  model,
  param_index,
  param_name,
  rate = 1,
  filename
){
  # Only self-recovery rows: data simulated by `model` and fitted by `model`.
  data <- dplyr::filter(
    data,
    simulate_model == fit_model & simulate_model == model
  )
  data <- dplyr::mutate(
    data,
    input_param = .data[[paste0("input_param_", param_index)]],
    output_param = .data[[paste0("output_param_", param_index)]]
  )

  # Seed once, before either branch, so the reference cloud (and therefore
  # the axis limits) is the same on every run for every parameter type.
  set.seed(123)

  if (param_name == "τ") {
    if (min(data$input_param) > 1) {
      # All simulated values exceed 1: plot log10(τ - 1) and say so in the
      # axis label.
      data <- dplyr::mutate(
        data,
        input_param = log10(input_param - 1),
        output_param = log10(output_param - 1)
      )
      param_name <- "τ-1"
    } else {
      data <- dplyr::mutate(
        data,
        input_param = log10(input_param),
        output_param = log10(output_param)
      )
    }

    # Reference cloud on the log10 scale: exponential draws plus randomly
    # signed log-exponential noise.
    x <- rexp(100, rate = rate)

    norm <- data.frame(
      x = log10(x),
      y = log10(x) +
        sample(c(-1, 1), 100, replace = TRUE) * log10(rexp(100, rate = rate))
    )

    # Axis limits: central 90% of the reference cloud, widened outwards to
    # whole powers of ten.
    x_min <- floor(quantile(norm$x, 0.05, na.rm = TRUE))
    x_max <- ceiling(quantile(norm$x, 0.95, na.rm = TRUE))

    y_min <- floor(quantile(norm$y, 0.05, na.rm = TRUE))
    y_max <- ceiling(quantile(norm$y, 0.95, na.rm = TRUE))
  } else {
    # Reference cloud on the raw scale: uniform draws on [0, 1] plus
    # Gaussian noise (sd 0.1).
    x <- runif(100, 0, 1)
    norm <- data.frame(
      x = x,
      y = x + rnorm(100, 0, 0.1)
    )

    # Axis limits: truncated range of the cloud, with an upper limit of at
    # least 1.
    x_min <- trunc(min(norm$x))
    x_max <- if (max(norm$x) < 1) 1 else trunc(max(norm$x))

    y_min <- trunc(min(norm$y))
    y_max <- if (max(norm$y) < 1) 1 else trunc(max(norm$y))
  }

  # Any label containing "τ" gets tick labels formatted as powers of ten;
  # everything else keeps ggplot2's default labels.
  axis_labels <- if (grepl("τ", param_name)) {
    function(x) parse(text = paste0("10^", x))
  } else {
    ggplot2::waiver()
  }

  fig <- ggplot2::ggplot(
    data,
    ggplot2::aes(x = input_param, y = output_param)
  ) +
    # Linear fit through the reference cloud.
    ggplot2::geom_smooth(
      data = norm,
      mapping = ggplot2::aes(x = x, y = y),
      method = "lm", color = "#55c186", se = FALSE
    ) +
    # The reference cloud itself.
    ggplot2::geom_point(
      data = norm,
      mapping = ggplot2::aes(x = x, y = y),
      color = "#55c186",
      alpha = 0.5, shape = 1
    ) +
    # Draw the scatter points of the actual recovery results.
    ggplot2::geom_point(color = "#053562") +
    ggplot2::scale_y_continuous(
      limits = c(y_min, y_max + 0.1), expand = c(0, 0),
      labels = axis_labels
    ) +
    ggplot2::scale_x_continuous(
      limits = c(x_min, x_max + 0.1), expand = c(0, 0),
      labels = axis_labels
    ) +
    ggplot2::labs(
      x = paste("simulated", param_name),
      y = paste("fit", param_name),
      title = paste(model, "Model")
    ) +
    papaja::theme_apa() +
    ggplot2::theme(
      text = ggplot2::element_text(
        family = "serif",
        face = "bold",
        size = 15
      ),
      axis.text = ggplot2::element_text(
        color = "black",
        family = "serif",
        face = "plain",
        size = 10
      ),
      plot.title = ggplot2::element_text(
        family = "serif", face = "bold",
        size = 15, hjust = 0.5
      ),
      plot.margin = ggplot2::margin(t = 1, r = 1, b = 1, l = 1)
    )

  ggplot2::ggsave(
    plot = fig,
    filename = filename,
    width = 4, height = 3
  )

  fig
}

eta

# TD model, parameter 1, labelled η
plot_param_rcv(
  data = read.csv("../OUTPUT/result_recovery.csv"),
  model = "TD", param_index = 1, param_name = "η",
  filename = "../FIGURE/1_TD_eta.png"
)

# RSTD model, parameter 1, labelled η-
plot_param_rcv(
  data = read.csv("../OUTPUT/result_recovery.csv"),
  model = "RSTD", param_index = 1, param_name = "η-",
  filename = "../FIGURE/2_RSTD_eta1.png"
)

# RSTD model, parameter 2, labelled η+
plot_param_rcv(
  data = read.csv("../OUTPUT/result_recovery.csv"),
  model = "RSTD", param_index = 2, param_name = "η+",
  filename = "../FIGURE/2_RSTD_eta2.png"
)

# Utility model, parameter 1, labelled η
plot_param_rcv(
  data = read.csv("../OUTPUT/result_recovery.csv"),
  model = "Utility", param_index = 1, param_name = "η",
  filename = "../FIGURE/3_Utility_eta.png"
)

TD_eta Utility_eta

RSTD_eta1 RSTD_eta2

# Correlation between simulated and recovered values, self-recovery rows only.

# TD, parameter 1 (η)
data <- dplyr::filter(
  read.csv("../OUTPUT/result_recovery.csv"),
  simulate_model == fit_model,
  simulate_model == "TD"
)
cor(x = data$input_param_1, y = data$output_param_1)
## [1] 0.5206546

# RSTD, parameter 1 (η-)
data <- dplyr::filter(
  read.csv("../OUTPUT/result_recovery.csv"),
  simulate_model == fit_model,
  simulate_model == "RSTD"
)
cor(x = data$input_param_1, y = data$output_param_1)
## [1] 0.3172339

# RSTD, parameter 2 (η+)
data <- dplyr::filter(
  read.csv("../OUTPUT/result_recovery.csv"),
  simulate_model == fit_model,
  simulate_model == "RSTD"
)
cor(x = data$input_param_2, y = data$output_param_2)
## [1] 0.4627732

# Utility, parameter 1 (η)
data <- dplyr::filter(
  read.csv("../OUTPUT/result_recovery.csv"),
  simulate_model == fit_model,
  simulate_model == "Utility"
)
cor(x = data$input_param_1, y = data$output_param_1)
## [1] 0.4622493

gamma

# Utility model, parameter 2, labelled γ
plot_param_rcv(
  data = read.csv("../OUTPUT/result_recovery.csv"),
  model = "Utility", param_index = 2, param_name = "γ",
  filename = "../FIGURE/3_Utility_gamma.png"
)

RL Models

# Utility, parameter 2 (γ): correlation of simulated vs. recovered values
data <- dplyr::filter(
  read.csv("../OUTPUT/result_recovery.csv"),
  simulate_model == fit_model,
  simulate_model == "Utility"
)
cor(x = data$input_param_2, y = data$output_param_2)
## [1] 0.6682462

tau

# τ is the last parameter of each model (index 2 for TD, 3 for RSTD and
# Utility); every call uses rate = 1 for the reference cloud.

# TD model
plot_param_rcv(
  data = read.csv("../OUTPUT/result_recovery.csv"),
  model = "TD", param_index = 2, param_name = "τ", rate = 1,
  filename = "../FIGURE/1_TD_tau.png"
)

# RSTD model
plot_param_rcv(
  data = read.csv("../OUTPUT/result_recovery.csv"),
  model = "RSTD", param_index = 3, param_name = "τ", rate = 1,
  filename = "../FIGURE/2_RSTD_tau.png"
)

# Utility model
plot_param_rcv(
  data = read.csv("../OUTPUT/result_recovery.csv"),
  model = "Utility", param_index = 3, param_name = "τ", rate = 1,
  filename = "../FIGURE/3_Utility_tau.png"
)

TD_tau RSTD_tau Utility_tau

# Correlation between simulated and recovered τ (raw scale), self-recovery
# rows only.

# TD, parameter 2
data <- dplyr::filter(
  read.csv("../OUTPUT/result_recovery.csv"),
  simulate_model == fit_model,
  simulate_model == "TD"
)
cor(x = data$input_param_2, y = data$output_param_2)
## [1] 0.5453743

# RSTD, parameter 3
data <- dplyr::filter(
  read.csv("../OUTPUT/result_recovery.csv"),
  simulate_model == fit_model,
  simulate_model == "RSTD"
)
cor(x = data$input_param_3, y = data$output_param_3)
## [1] 0.4646312

# Utility, parameter 3
data <- dplyr::filter(
  read.csv("../OUTPUT/result_recovery.csv"),
  simulate_model == fit_model,
  simulate_model == "Utility"
)
cor(x = data$input_param_3, y = data$output_param_3)
## [1] 0.5726734

Model Recovery

Plot Function
#' Plot a model-recovery matrix (confusion or inversion)
#'
#' For every simulated data set (one `simulate_model` x `iteration` pair) the
#' fitted model(s) with the lowest BIC score 1 and the others 0. Averaging
#' these scores per simulated model, rounded to two decimals, gives the
#' confusion matrix. The inversion matrix rescales those rounded values so
#' that each fitted model's cells sum to 1; a fitted model that never scores
#' has a zero sum, so its cells become NaN.
#'
#' The result is drawn as a labelled heat map, written to `filename` and
#' returned. Only the models TD, RSTD and Utility are supported.
#'
#' @param data Data frame with columns `simulate_model`, `fit_model`,
#'   `iteration` and `BIC`. Defaults to the saved recovery results.
#' @param matrix_type Either "confusion" or "inversion".
#' @param filename Output path handed to `ggplot2::ggsave()`.
#' @return The ggplot object that was saved.
plot_model_rcv <- function(
  data = read.csv("../OUTPUT/result_recovery.csv"),
  matrix_type,
  filename
){
  # Fail early with a clear message; an unknown type used to fall through
  # switch() and surface later as an unrelated error.
  if (!is.character(matrix_type) || length(matrix_type) != 1L ||
      !(matrix_type %in% c("confusion", "inversion"))) {
    stop(
      "`matrix_type` must be either \"confusion\" or \"inversion\".",
      call. = FALSE
    )
  }

  model_levels <- c("TD", "RSTD", "Utility")

  # One row per simulated data set, one BIC column per fitted model.
  bic_wide <- tidyr::pivot_wider(
    dplyr::select(data, simulate_model, fit_model, iteration, BIC),
    names_from = "fit_model",
    values_from = "BIC"
  )

  # 1 for the model(s) attaining the lowest BIC in a row, 0 otherwise.
  scored <- dplyr::mutate(
    bic_wide,
    TD_score = ifelse(TD == pmin(TD, RSTD, Utility), 1, 0),
    RSTD_score = ifelse(RSTD == pmin(TD, RSTD, Utility), 1, 0),
    Utility_score = ifelse(Utility == pmin(TD, RSTD, Utility), 1, 0)
  )

  # Share of data sets, per simulated model, in which each fitted model wins.
  win_rate <- dplyr::summarise(
    dplyr::group_by(scored, simulate_model),
    TD = round(mean(TD_score), 2),
    RSTD = round(mean(RSTD_score), 2),
    Utility = round(mean(Utility_score), 2)
  )
  win_rate <- dplyr::arrange(
    dplyr::ungroup(win_rate),
    factor(simulate_model, levels = model_levels)
  )

  # Long format with ordered factors so the tiles appear in a fixed order.
  data <- tidyr::pivot_longer(
    win_rate,
    cols = -simulate_model,
    names_to = "fit_model",
    values_to = "value"
  )
  data <- dplyr::mutate(
    data,
    simulate = factor(simulate_model, levels = model_levels),
    fit = factor(fit_model, levels = model_levels)
  )

  if (matrix_type == "inversion") {
    # Rescale within each fitted model so its cells sum to 1.
    data <- dplyr::ungroup(
      dplyr::mutate(
        dplyr::group_by(data, fit_model),
        value = round(value / sum(value), 2)
      )
    )
    title <- "Inversion Matrix: P(simulated model | fit model)"
    hjust <- 1.5
  } else {
    title <- "Confusion Matrix: P(fit model | simulated model)"
    hjust <- 1.3
  }

  fig <- ggplot2::ggplot(
    data,
    ggplot2::aes(x = simulate, y = fit, fill = value)
  ) +
    ggplot2::geom_tile() +
    ggplot2::scale_fill_gradient2(
      low = "#003161", mid = "#55c186", high = "#f0de36",
      limits = c(0, 1),
      midpoint = 0.4
    ) +
    ggplot2::labs(
      title = title,
      x = "simulate model",
      y = "fit model"
    ) +
    # Print each cell's value on top of its tile.
    ggplot2::geom_text(
      ggplot2::aes(label = value),
      size = 10, color = "white"
    ) +
    ggplot2::theme_minimal() +
    ggplot2::theme(
      legend.position = "none",
      panel.background = ggplot2::element_rect(fill = "white", color = NA),
      plot.background = ggplot2::element_rect(fill = "white", color = NA),
      plot.margin = ggplot2::margin(t = 10, r = 10, b = 10, l = 10),
      plot.title = ggplot2::element_text(
        family = "serif", face = "bold",
        size = 25, hjust = hjust, margin = ggplot2::margin(b = 20)
      ),
      axis.title.y = ggplot2::element_text(
        family = "serif", face = "bold",
        size = 25, margin = ggplot2::margin(r = 20)
      ),
      axis.title.x = ggplot2::element_text(
        family = "serif", face = "bold",
        size = 25, margin = ggplot2::margin(t = 20)
      ),
      axis.text = ggplot2::element_text(
        color = "black", family = "serif", face = "bold",
        size = 20
      )
    )

  ggplot2::ggsave(
    plot = fig,
    filename = filename,
    width = 8, height = 6
  )

  fig
}
# Confusion matrix figure
plot_model_rcv(
  data = read.csv("../OUTPUT/result_recovery.csv"),
  matrix_type = "confusion", filename = "../FIGURE/matrix_confusion.png"
)

# Inversion matrix figure
plot_model_rcv(
  data = read.csv("../OUTPUT/result_recovery.csv"),
  matrix_type = "inversion", filename = "../FIGURE/matrix_inversion.png"
)

confusion_matrix inversion_matrix