Skip to contents

Data generation

The example simulation experiment data (exposure X, outcome Y, and potential mediators M) was generated as follows.

For each subject \(i = 1, ..., N\):

  • \(X_i \stackrel{i.i.d}{\sim} N(0, 1)\)
  • \(M_{i,v} = a_v X_i + e_{1_{i,v}}\), where \(e_{1_{i,v}} \stackrel{i.i.d}{\sim} N(0, 1), v = 1, ..., V\)
  • \(Y_i = X_i + \sum_v b_{1_{v}} M_{i,v} + \sum_v b_{2_{v}} X_i \times M_{i,v} + e_{2_{i}}\), where \(e_{2_{i}} \stackrel{i.i.d}{\sim} N(0, 1)\)

The first three M variables (\(M_1, M_2, M_3\)) are set to be the true mediators (i.e., they have non-zero \(a\) and \(b_1\) coefficients), \(X \times M_1\) is set to be the true exposure-by-mediator interaction term (i.e., it has a non-zero \(b_2\) coefficient), and all other coefficients are set to 0. The effect size (ES) is the value given to the non-zero \(a\), \(b_1\), and \(b_2\) coefficients, which are \(a_1, a_2, a_3, b_{1_{1}}, b_{1_{2}}, b_{1_{3}}, b_{2_{1}}\) in our case.

For example, with N = 200, V = 50 and ES = 1, we can generate the example simulation experiment data using the following code.

## load package
library(XMInt)

## data generation
# Simulate one data set, then pull the exposure (X), outcome (Y) and
# potential mediators (M) out of the returned object.
data <- dat_gen(N = 200, V = 50, es = 1, seed = 1)
X <- data$X
Y <- data$Y
M <- data$M

Simulation setting

  • True mediators: \(M_1, M_2, M_3\)
  • True exposure-by-mediator interaction: \(X \times M_1\)
  • Sample size: N = 100, 200, 400
  • Number of potential mediators: V = 50, 100, 200, 400
  • Effect size: ES = 0.25, 0.5, 0.75, 1
  • Number of simulation runs: 100

Code

library(XMInt)

library(parallel)
library(doParallel) 
library(foreach) 
library(iterators) 

library(dplyr)


# Register a parallel backend for the %dopar% loops below.
nCores <- 4
registerDoParallel(nCores)

# Names of the truly active mediators and of the mediator involved in the
# true exposure-by-mediator interaction; used to score each selection.
truth_med <- paste0("M", 1:3)
truth_int <- "M1"



# --- simulation for N = 100, 200 and 400 ---

# Full simulation setting.
sim.times <- 100
V_list <- c(50, 100, 200, 400)
ES_list <- c(0.25, 0.5, 0.75, 1)

# It may take a while to run 100 simulations.
# Here we use 4 simulations to illustrate the simulation code.
sim.times <- 4
V_list <- c(50, 100)
ES_list <- c(0.5, 0.75, 1)

# Run every (V, ES, simulation) case for one sample size and score each run.
#
# Arguments:
#   N          sample size passed to dat_gen()
#   V_list     numbers of potential mediators to simulate
#   ES_list    effect sizes to simulate
#   sim.times  number of runs per (V, ES) case; run s uses seed = s
#   truth_med  names of the true mediators
#   truth_int  names of the mediators in the true exposure-by-mediator
#              interaction
#
# Value: a data frame with one row per successful run and the columns
#   N, V, ES, sim, sim_med, sim_int (selected names joined by "_"),
#   tpr_med, tpr_int, fdr_med, fdr_int (proportions in [0, 1]).
#   Runs in which XMInt_select() raised an error are left out and reported
#   through a warning.
#
# Note: the plotting code collects the result tables by object name
# (re_N100, re_N200, re_N400), so the names defined here deliberately do not
# contain the substring "re_".
run_sim_grid <- function(N, V_list, ES_list, sim.times, truth_med, truth_int) {
  # list order in out: out[[V]][[ES]][[sim.times]]
  out <- foreach(V = V_list, .packages = c("XMInt")) %:%
    foreach(ES = ES_list) %:%
    foreach(s = seq_len(sim.times)) %dopar% {
      # do simulation
      cat("Case : N = ",N, ", V = ", V, ", ES = ", ES, ", Simulation = ", s, "\n", sep = "")
      # data
      data <- dat_gen(N, V, ES, seed = s)
      # try() turns a failed selection into a "try-error" object, so a single
      # failing run does not abort the whole grid.
      try(XMInt_select(data$X, data$Y, data$M), silent = TRUE)
    }

  # Enumerate the cases in the same order as the nested loops above
  # (V outermost, simulation index innermost).
  cases <- expand.grid(
    s = seq_len(sim.times),
    j = seq_along(ES_list),
    i = seq_along(V_list)
  )

  # compute tpr, fdr
  rows <- lapply(seq_len(nrow(cases)), function(k) {
    i <- cases$i[k]
    j <- cases$j[k]
    s <- cases$s[k]
    fit <- out[[i]][[j]][[s]]
    if (inherits(fit, "try-error")) {
      return(NULL)
    }
    med <- fit$selected_mediator
    int <- fit$selected_interaction
    data.frame(
      N = N, V = V_list[i], ES = ES_list[j], sim = s,
      sim_med = paste(med, collapse = "_"),
      sim_int = paste(int, collapse = "_"),
      tpr_med = sum(med %in% truth_med) / length(truth_med),
      tpr_int = sum(int %in% truth_int) / length(truth_int),
      # max(1, .) keeps the FDR at 0 when nothing was selected
      fdr_med = sum(!med %in% truth_med) / max(1, length(med)),
      fdr_int = sum(!int %in% truth_int) / max(1, length(int))
    )
  })

  failed <- vapply(rows, is.null, logical(1))
  if (any(failed)) {
    warning(
      sum(failed), " of ", length(rows), " simulation runs failed for N = ", N,
      " and were dropped from the results.",
      call. = FALSE
    )
  }
  do.call(rbind, rows[!failed])
}

# --- simulation for N = 100 ---
re_N100 <- run_sim_grid(100, V_list, ES_list, sim.times, truth_med, truth_int)

# --- simulation for N = 200 ---
re_N200 <- run_sim_grid(200, V_list, ES_list, sim.times, truth_med, truth_int)

# --- simulation for N = 400 ---
re_N400 <- run_sim_grid(400, V_list, ES_list, sim.times, truth_med, truth_int)


stopImplicitCluster()






# --- plot results --- 

library(tidyverse)

# Collect the per-sample-size result tables (re_N100, re_N200, re_N400) and
# express all rates as percentages.
# The pattern is anchored on purpose: an unanchored "re_" also matches
# `re_list` itself when this code is re-run in the same session, and any
# other object whose name merely contains "re_", which breaks the rbind().
re_list <- ls(pattern = "^re_N[0-9]+$")
plt_dat <- do.call(rbind, lapply(re_list, get)) %>%
  mutate(tpr_med = tpr_med * 100,
         fdr_med = fdr_med * 100,
         tpr_int = tpr_int * 100,
         fdr_int = fdr_int * 100)

## mean
# Average each rate over the simulation runs within every (N, V, ES) case.
# .groups = "drop" returns an ungrouped table, so no grouping lingers into
# the plotting code and summarise() does not emit its regrouping message.
pr_mean_dat <-
  plt_dat %>%
  group_by(N, V, ES) %>%
  summarise(mean_tpr_med = mean(tpr_med),
            mean_tpr_int = mean(tpr_int),
            mean_fdr_med = mean(fdr_med),
            mean_fdr_int = mean(fdr_int),
            .groups = "drop")
pr_mean_dat %>% data.frame()

## tpr, med
# Average true positive rate (%) for mediators against sample size:
# one line per effect size, one facet row per number of potential mediators.
# N and ES are converted to character so they are treated as discrete.
sim.tpr.plt.med.mean <- ggplot(
  data = mutate(
    pr_mean_dat,
    N = as.character(N),
    V = factor(
      V,
      levels = c(50, 100, 200, 400),
      labels = c("V = 50", "V = 100", "V = 200", "V = 400")
    ),
    ES = as.character(ES)
  ),
  mapping = aes(y = mean_tpr_med, x = N, colour = ES, group = ES, shape = ES)
) +
  geom_point(size = 2) +
  geom_line() +
  facet_grid(V ~ ., scales = "free") +
  # fixed 0-100 axis with a tick every 20 percentage points
  scale_y_continuous(limits = c(0, 100), breaks = seq(0, 100, by = 20)) +
  # dashed reference line at 80%
  geom_hline(yintercept = 80, linetype = "dashed") +
  labs(
    title = "Avg. True Positive Rate: Mediators",
    x = "Sample Size (N)",
    y = "Avg. True Positive Rate (%)"
  ) +
  theme_bw()
sim.tpr.plt.med.mean

## tpr, int
# Average true positive rate (%) for the interaction against sample size:
# one line per effect size, one facet row per number of potential mediators.
# N and ES are converted to character so they are treated as discrete.
sim.tpr.plt.int.mean <- ggplot(
  data = mutate(
    pr_mean_dat,
    N = as.character(N),
    V = factor(
      V,
      levels = c(50, 100, 200, 400),
      labels = c("V = 50", "V = 100", "V = 200", "V = 400")
    ),
    ES = as.character(ES)
  ),
  mapping = aes(y = mean_tpr_int, x = N, colour = ES, group = ES, shape = ES)
) +
  geom_point(size = 2) +
  geom_line() +
  facet_grid(V ~ ., scales = "free") +
  # fixed 0-100 axis with a tick every 20 percentage points
  scale_y_continuous(limits = c(0, 100), breaks = seq(0, 100, by = 20)) +
  # dashed reference line at 80%
  geom_hline(yintercept = 80, linetype = "dashed") +
  labs(
    title = "Avg. True Positive Rate: Interaction",
    x = "Sample Size (N)",
    y = "Avg. True Positive Rate (%)"
  ) +
  theme_bw()
sim.tpr.plt.int.mean


## fdr, med
# Average false discovery rate (%) for mediators against sample size:
# one line per effect size, one facet row per number of potential mediators.
# N and ES are converted to character so they are treated as discrete.
sim.fdr.plt.med.mean <- ggplot(
  data = mutate(
    pr_mean_dat,
    N = as.character(N),
    V = factor(
      V,
      levels = c(50, 100, 200, 400),
      labels = c("V = 50", "V = 100", "V = 200", "V = 400")
    ),
    ES = as.character(ES)
  ),
  mapping = aes(y = mean_fdr_med, x = N, colour = ES, group = ES, shape = ES)
) +
  geom_point(size = 2) +
  geom_line() +
  # fixed 0-100 axis with a tick every 20 percentage points
  scale_y_continuous(limits = c(0, 100), breaks = seq(0, 100, by = 20)) +
  # dashed reference line at 10%
  geom_hline(yintercept = 10, linetype = "dashed") +
  facet_grid(V ~ ., scales = "free") +
  labs(
    title = "Avg. False Discovery Rate: Mediators",
    x = "Sample Size (N)",
    y = "Avg. False Discovery Rate (%)"
  ) +
  theme_bw()
sim.fdr.plt.med.mean


## fdr, int
# Average false discovery rate (%) for the interaction against sample size:
# one line per effect size, one facet row per number of potential mediators.
# N and ES are converted to character so they are treated as discrete.
sim.fdr.plt.int.mean <- ggplot(
  data = mutate(
    pr_mean_dat,
    N = as.character(N),
    V = factor(
      V,
      levels = c(50, 100, 200, 400),
      labels = c("V = 50", "V = 100", "V = 200", "V = 400")
    ),
    ES = as.character(ES)
  ),
  mapping = aes(y = mean_fdr_int, x = N, colour = ES, group = ES, shape = ES)
) +
  geom_point(size = 2) +
  geom_line() +
  # fixed 0-100 axis with a tick every 20 percentage points
  scale_y_continuous(limits = c(0, 100), breaks = seq(0, 100, by = 20)) +
  # dashed reference line at 10%
  geom_hline(yintercept = 10, linetype = "dashed") +
  facet_grid(V ~ ., scales = "free") +
  labs(
    title = "Avg. False Discovery Rate: Interaction",
    x = "Sample Size (N)",
    y = "Avg. False Discovery Rate (%)"
  ) +
  theme_bw()
sim.fdr.plt.int.mean


# combine plots
# Draw the four panels together on two rows (mediator plots listed first,
# interaction plots second).
# NOTE(review): grid.arrange() is taken from pdp here; confirm that the
# installed pdp version still provides it.
library(pdp)
# pdf("fig2.pdf", height = 8, width = 10) # save as pdf
grid.arrange(
  sim.tpr.plt.med.mean, sim.fdr.plt.med.mean,
  sim.tpr.plt.int.mean, sim.fdr.plt.int.mean,
  nrow = 2
)
# dev.off()