Advertisement
Guest User

Untitled

a guest
Oct 17th, 2017
78
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
R 2.18 KB | None | 0 0
  # Fit the pre-compiled Stan model `ridge_comp` to standardized data at one
  # or more prior scales.
  #
  # `ridge_comp` and `sampling()` are not defined in this function; they must
  # be available in the calling environment (presumably a compiled stanmodel
  # and rstan::sampling -- confirm against the rest of the project).
  #
  # Arguments:
  #   X             Numeric matrix or data frame. Columns are standardized
  #                 with scale() and rows containing NA are dropped.
  #   prior_scale   Positive numeric vector with one value per model.
  #   chains        Number of chains passed to sampling().
  #   models        Number of models to fit; must equal length(prior_scale).
  #   scale_seq     Unused. Kept only so existing calls that name it still
  #                 work.
  #   iter          Iterations per chain passed to sampling().
  #   adapt_delta, max_treedepth
  #                 Forwarded to sampling() through its `control` list.
  #
  # Returns:
  #   models == 1:  the object returned by sampling().
  #   models > 1:   list(mod_fit = <named list of fits, one per prior scale>,
  #                      stan_dat = <data list used for the LAST fit>).
  #
  # Errors if a prior scale is not positive, if `models` does not match
  # length(prior_scale), or if no complete rows remain after NA removal.
  bridge_net <- function(X, prior_scale = 1, chains = 4, models = 1,
                         scale_seq = NULL, iter = 2000,
                         adapt_delta = 0.80, max_treedepth = 10) {
    # ---- argument validation --------------------------------------------
    models <- suppressWarnings(as.numeric(models))
    if (length(models) != 1 || is.na(models) || models < 1) {
      stop("models must be a single number greater than or equal to 1",
           call. = FALSE)
    }
    if (!is.numeric(prior_scale) || length(prior_scale) == 0 ||
        anyNA(prior_scale) || any(prior_scale <= 0)) {
      stop("The scale must be positive. The default is set to 1",
           call. = FALSE)
    }
    # Each fit receives a single scalar prior_scale, so the number of models
    # and the number of scales have to agree in the single-model case too.
    if (models != length(prior_scale)) {
      stop("Models must be same length as prior scale", call. = FALSE)
    }
    if (iter > 2000) {
      warning("Increasing iterations will make model fitting slower and may not be necessary for convergence!",
              call. = FALSE)
    }

    # ---- worker count ---------------------------------------------------
    # Leave two cores free, never ask for more workers than chains, and
    # never drop below one (detectCores() may return NA or a value <= 2).
    avail <- parallel::detectCores()
    cores <- if (is.na(avail)) 1 else max(1, min(chains, avail - 2))

    # ---- data preparation -----------------------------------------------
    if (any(is.na(X))) {
      warning("NA values detected and removed", call. = FALSE)
    }
    X <- as.matrix(na.omit(scale(X)))
    N <- nrow(X)
    K <- ncol(X)
    if (N == 0) {
      # Happens when every row has an NA, or when a constant column turns
      # into NaN under scale() and na.omit() then drops every row.
      stop("No complete rows remain after scaling and removing NA values",
           call. = FALSE)
    }

    sampler_control <- list(adapt_delta = adapt_delta,
                            max_treedepth = max_treedepth)

    # ---- single model ---------------------------------------------------
    if (models == 1) {
      stan_dat <- list(N = N, K = K, X = X, prior_scale = prior_scale)
      return(sampling(ridge_comp, data = stan_dat,
                      chains = chains, iter = iter, cores = cores,
                      control = sampler_control))
    }

    # ---- several models, one per prior scale ----------------------------
    total <- length(prior_scale)
    pb <- utils::txtProgressBar(min = 0, max = total, style = 3)
    on.exit(close(pb), add = TRUE)

    mod_fit <- vector("list", total)
    for (i in seq_len(total)) {
      stan_dat <- list(N = N, K = K, X = X, prior_scale = prior_scale[i])
      mod_fit[[i]] <- sampling(ridge_comp, data = stan_dat, cores = cores,
                               chains = chains, iter = iter, refresh = 0,
                               control = sampler_control)
      # Advance the bar only once the fit has actually finished.
      utils::setTxtProgressBar(pb, i)
    }
    names(mod_fit) <- paste("prior_scale", prior_scale, sep = " ")

    # stan_dat is the data list from the final iteration (unchanged from the
    # original behaviour); its prior_scale is the last element supplied.
    list(mod_fit = mod_fit, stan_dat = stan_dat)
  }
  53.  
  54.  
  # Example run: three models, one per prior scale, on 4 chains with 5000
  # iterations each. `X` must already exist in the calling environment.
  test_run <- bridge_net(
    X,
    prior_scale = c(0.1, 0.5, 1),
    models = 3,
    chains = 4,
    iter = 5000
  )
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement