celestialgod

transitProbMatrix

Mar 13th, 2015
360
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
R 5.83 KB | None | 0 0
  1. library(data.table)
  2. library(dplyr)
  3. library(magrittr)
  # Data generation ----
  # Simulate N_patient subjects. Each subject gets a random number of visits
  # (1 to 48); every visit carries six independent 0/1 state flags M_1..M_6.
  # The result has one row per (id, duration) pair, sorted by id and duration.
  # NOTE(review): tbl_dt() may not be exported by recent dplyr releases --
  # confirm against the installed version.
  N_patient <- 3000
  dat <- sample(48, N_patient, replace = TRUE) %>%
    {
      # `.` holds the number of visits per subject. sequence() always returns
      # the concatenated 1..n runs as a flat vector, whereas sapply() + unlist()
      # yields a matrix when all lengths happen to be equal.
      cbind(rep(seq_len(N_patient), times = .), sequence(.))
    } %>%
    cbind(replicate(6, sample(0:1, nrow(.), replace = TRUE))) %>%
    tbl_dt() %>%
    setnames(c("id", "duration", paste0("M_", 1:6))) %>%
    arrange(id, duration)
  11.  
  # Method 1: pure R loop ----
  # For every subject, build one indicator matrix per step between consecutive
  # visits: row r is a copy of the current state vector when state r was set
  # to 1 at the previous visit, and all zeros otherwise. The matrices are
  # summed over subjects step by step and each row is then divided by its
  # total, so rows with at least one observed transition sum to 1.
  st <- proc.time()
  # Derive the dimensions from the data instead of hard-coding 6 states and
  # 47 steps, so the result lines up with the max(dat$duration) based version.
  n_state <- ncol(dat) - 2L         # every column except id and duration
  n_step <- max(dat$duration) - 1L  # steps between consecutive visits
  transitMatrix_eachTime <- dat %>%
    split(.$id) %>%
    lapply(function(dd) {
      out <- replicate(n_step, matrix(0, n_state, n_state), simplify = FALSE)
      states <- dd %>% select(-id, -duration) %>% as.matrix()
      # seq_len(n)[-1L] is empty for subjects with fewer than two visits.
      for (i in seq_len(nrow(states))[-1L]) {
        prev_on <- states[i - 1L, ] == 1
        if (any(prev_on)) {
          out[[i - 1L]][prev_on, ] <- matrix(
            states[i, ],
            nrow = sum(prev_on), ncol = n_state, byrow = TRUE
          )
        }
      }
      out
    }) %>%
    # Element-wise sum of the per-subject lists.
    Reduce(function(x, y) Map(`+`, x, y), .) %>%
    lapply(function(x) {
      row_total <- rowSums(x)
      # Rows without any transition keep their zeros (divide by 1, not by 0).
      x / ifelse(row_total > 0, row_total, 1)
    })
  proc.time() - st
  #   user  system elapsed
  #  32.10    0.19   33.26
  35.  
  36. library(Rcpp)
  37. library(RcppArmadillo)
  # Compile transitMatrix_f(): per-subject transition indicator matrices.
  sourceCpp(code = '
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
using namespace Rcpp;
using namespace arma;

// Build the transition indicator matrices of one subject.
//
// M           : one row per observation, one column per state.
// maxDuration : number of time points the output is aligned to.
//
// Returns a list of maxDuration - 1 square matrices (n_cols x n_cols).
// In element i - 1, every row whose state equals 1 at observation i - 1
// holds a copy of observation i; all other rows are zero. Steps beyond the
// observations available in M stay all-zero.
// [[Rcpp::export]]
List transitMatrix_f(umat M, uword maxDuration)
{
  // uword is unsigned, so maxDuration - 1 would wrap around for 0.
  if (maxDuration < 2)
    return List();

  const uword n_obs = M.n_rows;
  List transitMatrixList(maxDuration - 1);
  umat transitMatrix(M.n_cols, M.n_cols);
  urowvec previous_M(M.n_cols);
  for (uword i = 1; i < maxDuration; i++)
  {
    transitMatrix.zeros();
    if (i < n_obs)
    {
      previous_M = M.row(i - 1);
      const uvec active = find(previous_M == 1);
      // Size the copy by the number of selected rows rather than by the row
      // sum, so entries other than 0/1 cannot produce a size mismatch.
      if (!active.is_empty())
        transitMatrix.rows(active) = repmat(M.row(i), active.n_elem, 1);
    }
    // wrap() copies the data, so the buffer can be reused on the next step.
    transitMatrixList[i - 1] = wrap(transitMatrix);
  }
  return transitMatrixList;
}')
  64. library(RcppEigen)
  # Compile list_sum_f() and listAddition(): element-wise sums of matrix lists.
  sourceCpp(code = '
// [[Rcpp::depends(RcppEigen)]]
#include <RcppEigen.h>
using namespace Rcpp;
using Eigen::Map;
using Eigen::MatrixXd;
using Eigen::VectorXd;

// Add every matrix of Xr onto the matrix at the same position of Yr.
// Yr is updated in place and nothing is returned. Elements are mapped as
// double matrices (Map<MatrixXd>).
// [[Rcpp::export]]
void list_sum_f(List Xr, List Yr) {
  // Refuse to write past the end of the accumulator.
  if (Xr.size() > Yr.size())
    stop("list_sum_f: Yr has fewer elements than Xr");
  for (int i = 0; i < Xr.size(); i++)
    Yr[i] = as< Map<MatrixXd> >(Xr[i]) + as< Map<MatrixXd> >(Yr[i]);
}

// Sum a list of equally shaped lists of matrices position by position.
// Returns a list with one summed matrix per position, or an empty list when
// Xr itself is empty.
// [[Rcpp::export]]
List listAddition(List Xr) {
  const int n = Xr.size();
  if (n == 0)
    return List();
  // Accumulate into a deep copy: an Rcpp List only wraps the underlying R
  // object, so writing into Xr[0] directly would alter the input list.
  List first = Xr[0];
  List list_sum = clone(first);
  for (int j = 1; j < n; j++)
    list_sum_f(Xr[j], list_sum);
  return list_sum;
}')
  87.  
  # Method 2: compiled helpers ----
  # Same computation as the pure R loop, but the per-subject matrices come
  # from transitMatrix_f() and the element-wise sum from listAddition().
  st <- proc.time()
  maxDuration <- max(dat$duration)
  transitMatrix_eachTime2 <- lapply(
    listAddition(
      lapply(split(dat, dat$id), function(one_patient) {
        state_mat <- as.matrix(select(one_patient, -id, -duration))
        transitMatrix_f(state_mat, maxDuration)
      })
    ),
    # Row-normalise; rows without any transition are divided by 1.
    function(counts) counts / ifelse(rowSums(counts) > 0, rowSums(counts), 1)
  )
  proc.time() - st
  #   user  system elapsed
  #  14.68    0.20   15.27
  98.  
  99. library(reshape2)
  100. library(Matrix)
  # Method 3: long format + keyed data.table join ----
  st <- proc.time()

  # Within each id, shift every state column down by one row (padding the
  # first row with 0) and rename the six new columns to M_1p .. M_6p, i.e.
  # the state at the previous visit.
  dat_previous <- dat %>%
    group_by(id) %>%
    mutate_(.dots = paste0("c(0, M_", 1:6, "[-length(M_1)])")) %>%
    setnames(old = tail(names(.), 6), new = paste0("M_", 1:6, "p"))

  # Current states that are switched on, from the second visit onwards.
  dat_transform_1 <- dat_previous %>%
    melt(id.vars = c("id", "duration"), measure.vars = paste0("M_", 1:6)) %>%
    filter(value == 1, duration > 1) %>%
    select(-value) %>%
    transform(variable = as.numeric(substr(variable, 3, 3))) %>%
    setnames("variable", "M") %>%
    setkey(id, duration)

  # Previous states that were switched on.
  dat_transform_2 <- dat_previous %>%
    melt(
      id.vars = c("id", "duration"),
      measure.vars = paste0("M_", 1:6, "p")
    ) %>%
    filter(value == 1) %>%
    select(-value) %>%
    mutate(variable = as.numeric(substr(variable, 3, 3))) %>%
    setnames("variable", "M_p") %>%
    setkey(id, duration)

  # Pair every previous state with every current state of the same visit;
  # rows left unmatched by the join carry NA and are dropped.
  dat_combined <- filter(
    dat_transform_1[dat_transform_2, allow.cartesian = TRUE],
    !is.na(M),
    !is.na(M_p)
  )

  # Count the pairs per visit, turn the counts into proportions within each
  # previous state, and store one sparse 6 x 6 matrix per visit
  # (rows = previous state, columns = current state).
  transitMatrix_eachTime3 <- dat_combined %>%
    group_by(duration, M, M_p) %>%
    summarise(count = n()) %>%
    group_by(duration, M_p) %>%
    mutate(transitProb = count / sum(count)) %>%
    ungroup() %>%
    split(.$duration) %>%
    lapply(function(step) spMatrix(6, 6, step$M_p, step$M, step$transitProb))
  proc.time() - st
  #   user  system elapsed
  #   2.15    0.31    2.62
  128.  
  # Method 4: reshape2::melt + self inner_join ----
  st <- proc.time()

  # Long format: one row per (id, duration, M) for every state switched on.
  dat_transform <- dat %>%
    melt(id = c("id", "duration"), measure = paste0("M_", 1:6)) %>%
    filter(value == 1) %>% select(-value) %>%
    transform(variable = as.numeric(substr(variable, 3, 3))) %>%
    setnames("variable", "M") %>% setkey(id, duration)

  # Join each visit with the same subject's rows shifted forward by one
  # visit, so M.x is the current state and M.y the state at the previous visit.
  dat_combined <- dat_transform %>% filter(duration > 1) %>%
    inner_join(dat_transform %>% transform(duration = duration + 1),
               by = c("id", "duration"))

  # Store the result under its own name: the equality checks at the end of
  # the script read transitMatrix_eachTime4, and transitMatrix_eachTime5 is
  # assigned by the tidyr variant.
  # One sparse 6 x 6 matrix per visit (rows = previous, columns = current).
  transitMatrix_eachTime4 <- dat_combined %>%
    group_by(duration, M.x, M.y) %>% summarise(count = n()) %>%
    group_by(duration, M.y) %>%
    mutate(transitProb = count / sum(count)) %>% ungroup() %>%
    split(.$duration) %>%
    lapply(function(x) spMatrix(6, 6, x$M.y, x$M.x, x$transitProb))
  proc.time() - st
  #   user  system elapsed
  #   1.25    0.19    1.45
  149.  
  150. library(tidyr)
  # Method 5: tidyr::gather + self inner_join ----
  st <- proc.time()

  # Long format: keep one row per state that is switched on and turn the
  # column name "M_k" into the numeric state index k.
  dat_transform <- dat %>%
    gather(M, value, M_1:M_6) %>%
    filter(value == 1) %>%
    select(-value) %>%
    transform(M = as.numeric(substr(M, 3, 3)))

  # Rows shifted forward by one visit supply the previous state (M.y) for the
  # current state (M.x) of the same subject.
  dat_combined <- inner_join(
    filter(dat_transform, duration > 1),
    transform(dat_transform, duration = duration + 1),
    by = c("id", "duration")
  )

  # Proportions within each previous state, one sparse 6 x 6 matrix per visit.
  transitMatrix_eachTime5 <- dat_combined %>%
    group_by(duration, M.x, M.y) %>%
    summarise(count = n()) %>%
    group_by(duration, M.y) %>%
    mutate(transitProb = count / sum(count)) %>%
    ungroup() %>%
    split(.$duration) %>%
    lapply(function(step) spMatrix(6, 6, step$M.y, step$M.x, step$transitProb))
  proc.time() - st
  #   user  system elapsed
  #   1.28    0.12    1.40
  169.  
  # Cross-check ----
  # Every variant is compared with the pure R result. The sparse results are
  # converted to dense matrices and stripped of their list names first.
  all.equal(transitMatrix_eachTime, transitMatrix_eachTime2)
  # TRUE
  all.equal(
    transitMatrix_eachTime,
    set_names(lapply(transitMatrix_eachTime3, as.matrix), NULL)
  )
  # TRUE
  all.equal(
    transitMatrix_eachTime,
    set_names(lapply(transitMatrix_eachTime4, as.matrix), NULL)
  )
  # TRUE
  all.equal(
    transitMatrix_eachTime,
    set_names(lapply(transitMatrix_eachTime5, as.matrix), NULL)
  )
  # TRUE
Advertisement
Add Comment
Please, Sign In to add comment