Skip to contents

Step 4: Fit and Compare Models

# Fit three candidate RL models (TD, RSTD, Utility) to the bundled
# Mason_2024_G2 data with binaryRL::fit_p, crossing two `policy` settings
# ("off", "on") with two `estimate` methods ("MLE", "MAP").
# `lower`, `upper` and `priors` each hold one entry per model, in the same
# order as `model_name` / `fit_model`. Their per-model lengths (2, 3, 3) line
# up with one another. Presumably they are matched to the model's parameters
# by position; confirm in the binaryRL::fit_p documentation.
comparison <- binaryRL::fit_p(
  data = binaryRL::Mason_2024_G2,
  policy = c("off", "on"),
  model_name = c("TD", "RSTD", "Utility"),
  fit_model = list(binaryRL::TD, binaryRL::RSTD, binaryRL::Utility),
  estimate = c("MLE", "MAP"),
  # Search bounds per model. Reading them against the prior names below,
  # presumably eta/gamma lie in [0, 1] and tau in [0, 10].
  lower = list(c(0, 0), c(0, 0, 0), c(0, 0, 0)),
  upper = list(c(1, 10), c(1, 1, 10), c(1, 1, 10)),
  # Prior log-densities (note log = TRUE): uniform on [0, 1] for eta/gamma,
  # exponential with rate 1 for tau. Presumably these are only consulted when
  # estimate = "MAP"; TODO confirm.
  priors = list(
    list(
      eta = function(x) { stats::dunif(x, min = 0, max = 1, log = TRUE) }, 
      tau = function(x) { stats::dexp(x, rate = 1, log = TRUE) }
    ), 
    # NOTE(review): this RSTD list names two entries `eta`. Presumably these
    # are the model's two learning rates and fit_p matches priors by
    # position. Verify this: a lookup by name would only ever find the first.
    list(
      eta = function(x) { stats::dunif(x, min = 0, max = 1, log = TRUE) }, 
      eta = function(x) { stats::dunif(x, min = 0, max = 1, log = TRUE) }, 
      tau = function(x) { stats::dexp(x, rate = 1, log = TRUE) }
    ), 
    list(
      eta = function(x) { stats::dunif(x, min = 0, max = 1, log = TRUE) }, 
      gamma = function(x) { stats::dunif(x, min = 0, max = 1, log = TRUE) }, 
      tau = function(x) { stats::dexp(x, rate = 1, log = TRUE) }
    )
  ),
  # Iteration counts passed through to fit_p (see binaryRL docs for the
  # exact meaning of the _i / _g variants).
  iteration_i = 100,
  iteration_g = 10,
  # Fixed seed so repeated runs are reproducible.
  seed = 123,
  # Presumably the number of parallel cores to use; adjust to the machine.
  nc = 16,
  # Alternative optimisers; uncomment exactly one and drop the line below.
  #algorithm = "L-BFGS-B"     # Gradient-Based (stats::optim)
  #algorithm = "GenSA"        # Simulated Annealing (GenSA::GenSA)
  #algorithm = "GA"           # Genetic Algorithm (GA::ga)
  #algorithm = "DEoptim"      # Differential Evolution (DEoptim::DEoptim)
  #algorithm = "PSO"          # Particle Swarm Optimization (pso::psoptim)
  #algorithm = "Bayesian"     # Bayesian Optimization (mlrMBO::mbo)
  #algorithm = "CMA-ES"       # Covariance Matrix Adapting (`cmaes::cma_es`)
  # Nonlinear Optimization (nloptr::nloptr)
  # Presumably a global MLSL search paired with BOBYQA as its local
  # optimiser; confirm against binaryRL's nloptr handling.
  algorithm = c("NLOPT_GN_MLSL", "NLOPT_LN_BOBYQA")
)

# Stack the fit results returned by fit_p into a single data frame.
# This presumes `comparison` is a list of data frames (one per fitted
# configuration); verify against fit_p's return value.
result <- dplyr::bind_rows(comparison)

# write.csv() cannot create missing directories. Create ../OUTPUT first
# (a silent no-op when it already exists) so the export does not fail on
# a fresh checkout.
dir.create("../OUTPUT", recursive = TRUE, showWarnings = FALSE)
write.csv(result, "../OUTPUT/result_comparison.csv", row.names = FALSE)

Plot

Read Results

# Load the saved fit results and turn them into long format: one row per
# original row and metric (ACC, LogL, AIC, BIC); parameter columns
# (param_*) are discarded.
# - fit_model / metric become factors with a fixed level order so every
#   plot lists models and metrics consistently. LogL is only *relabelled*
#   "-LogL" for display; its values are not negated here.
# - value_norm min-max rescales each metric within a subject to roughly
#   [0, 1]. The 1e-10 term avoids 0/0 when all rows tie (they then get 0).
data <- read.csv("../OUTPUT/result_comparison.csv") %>%
  dplyr::select(-dplyr::starts_with("param_")) %>%
  tidyr::pivot_longer(
    cols = c(ACC, LogL, AIC, BIC),
    names_to = "metric",
    values_to = "value"
  ) %>%
  dplyr::mutate(
    fit_model = factor(fit_model, levels = c("TD", "RSTD", "Utility")),
    metric = factor(
      metric,
      levels = c('ACC', 'LogL', 'AIC', 'BIC'),
      labels = c('ACC', '-LogL', 'AIC', 'BIC')
    )
  ) %>%
  dplyr::group_by(Subject, metric) %>%
  dplyr::mutate(
    value_norm = (value - min(value)) / (max(value) - min(value) + 1e-10)
  ) %>%
  dplyr::ungroup() %>%
  dplyr::arrange(Subject, fit_model)

Absolute

Plot Function
# Absolute comparison: build one bar chart per metric level of `data`.
# Each bar is the mean raw `value` of a model over all of that panel's rows,
# with +/- 1 standard-error bars (mean_se).
# Every panel zooms its y-axis to mean +/- 0.7 SD of the panel's values, so
# model differences stay visible despite the metrics' very different scales.
# All ggplot2 helpers are namespace-qualified, so this works without
# library(ggplot2).
plots <- data %>%
  split(.$metric) %>%
  purrr::map(~{
    # Centre and half-height of the zoomed y-window for this metric.
    y_mid <- mean(.x$value)
    y_half <- 0.7 * sd(.x$value)

    ggplot2::ggplot(
      .x,
      ggplot2::aes(x = fit_model, y = value, fill = fit_model)
    ) +
      ggplot2::geom_bar(stat = "summary", fun = "mean", position = "dodge") +
      ggplot2::geom_errorbar(
        stat = "summary", fun.data = "mean_se",
        position = ggplot2::position_dodge(width = 0.9), width = 0.2
      ) +
      ggplot2::scale_fill_manual(values = c("#053562", "#55c186", "#f0de36")) +
      ggplot2::labs(x = "", y = "", title = .x$metric[1]) +
      ggplot2::coord_cartesian(ylim = c(y_mid - y_half, y_mid + y_half)) +
      papaja::theme_apa() +
      ggplot2::theme(
        legend.position = "none",
        plot.margin = ggplot2::margin(t = 0, r = 0, b = 0, l = 0),
        text = ggplot2::element_text(
          family = "serif", face = "bold", size = 15
        ),
        axis.text = ggplot2::element_text(
          color = "black", family = "serif", face = "bold", size = 12
        )
      )
  })

# Arrange the per-metric panels in a two-column grid under a common title,
# write the figure to ../FIGURE (8 x 6, ggsave's default unit is inches)
# and print it.
plot <- patchwork::wrap_plots(plots, ncol = 2) +
  patchwork::plot_annotation(
    title = "Model Comparison (absolute)",
    theme = ggplot2::theme(
      plot.title = ggplot2::element_text(
        family = "serif", face = "bold",
        size = 25, hjust = 0, margin = ggplot2::margin(b = 20)
      )
    )
  )

# The individual panels are no longer needed once they are combined.
rm(plots)

# ggsave() refuses to write into a missing directory when non-interactive.
# Create it up front; this is a no-op when it already exists.
dir.create("../FIGURE", recursive = TRUE, showWarnings = FALSE)
ggplot2::ggsave(
  plot = plot,
  filename = "../FIGURE/model_comparison(abs).png",
  width = 8, height = 6
)

plot

Figure: absolute model comparison of the RL models (TD, RSTD, Utility)

Relative

Plot Function
# Relative comparison: for each metric, dodged bars show every model's mean
# within-subject normalised score (value_norm), with +/- 1 SE error bars,
# on a percentage axis clipped to 0-80%.
# The legend sits inside the panel. The figure is written to ../FIGURE and
# printed.
plot <- ggplot2::ggplot(
  data,
  ggplot2::aes(x = metric, y = value_norm, fill = fit_model)
) +
  ggplot2::geom_bar(stat = "summary", fun = "mean", position = "dodge") +
  ggplot2::geom_errorbar(
    stat = "summary", fun.data = "mean_se",
    # 0.9 matches geom_bar's default bar width, so error bars centre on bars.
    position = ggplot2::position_dodge(width = 0.9), width = 0.2
  ) +
  ggplot2::labs(x = "", y = "", fill = "Model") +
  ggplot2::ggtitle("Model Comparison (relative)") +
  ggplot2::scale_fill_manual(values = c("#053562", "#55c186", "#f0de36")) +
  ggplot2::coord_cartesian(ylim = c(0, 0.80)) +
  ggplot2::scale_y_continuous(labels = scales::percent) +
  papaja::theme_apa() +
  ggplot2::theme(
    plot.margin = ggplot2::margin(t = 5, r = 0, b = 0, l = 0),
    plot.title = ggplot2::element_text(
      family = "serif", face = "bold",
      size = 25, hjust = -1.4
    ),
    legend.title = ggplot2::element_blank(),
    # Since ggplot2 3.5.0, legend.position.inside only takes effect when
    # legend.position is "inside". Without this line the coordinates are
    # ignored and the legend stays outside the panel.
    legend.position = "inside",
    legend.position.inside = c(0.8, 0.75),
    legend.justification = c(0, 0),
    text = ggplot2::element_text(
      family = "serif",
      face = "bold",
      size = 25
    ),
    axis.text = ggplot2::element_text(
      color = "black",
      family = "serif",
      face = "bold",
      size = 20
    )
  )

# Make sure the output directory exists before ggsave() writes into it.
dir.create("../FIGURE", recursive = TRUE, showWarnings = FALSE)
ggplot2::ggsave(
  plot = plot,
  filename = "../FIGURE/model_comparison(rel).png",
  width = 8, height = 6
)

plot

Figure: relative (normalised) model comparison of the RL models (TD, RSTD, Utility)