Advertisement
celestialgod

count_cpp

Nov 18th, 2015
227
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
R 3.12 KB | None | 0 0
  1. # data generation
  2. set.seed(100)
  3. numData = 1e4
  4. dat = replicate(numData, as.integer(sample(1:100, 10)))
  5. # method 1
  6. st = proc.time()
  7. count = vector('integer', numData)
  8. for (i in 1:ncol(dat))
  9.   count[i] = sum(colSums(matrix(dat %in% dat[,i], 10)) >= 5) - 1
  10. proc.time() - st
  11. #   user  system elapsed
  12. #  22.84    0.91   23.74
  13.  
  14. # method 2 by Rcpp
  15. library(Rcpp)
  16. library(RcppArmadillo)
  17. sourceCpp(code = '
  18. // [[Rcpp::depends(RcppArmadillo)]]
  19. #define ARMA_DONT_USE_CXX11
  20. #include <RcppArmadillo.h>
  21. using namespace Rcpp;
  22. using namespace arma;
  23.  
  24. // [[Rcpp::export]]
  25. Col<int> count_cpp(IntegerMatrix xr) {
  26. Mat<int> x(xr.begin(), xr.nrow(), xr.ncol(), false);
  27. Col<int> out = zeros< Col<int> >(xr.ncol());
  28. int count;
  29. for (uword i = 0; i < x.n_cols; i++)
  30. {
  31.   for (uword j = 0; j < x.n_cols; j++)
  32.   {
  33.     count = 0;
  34.     for (uword k = 0; k < x.n_rows; k++)
  35.       for (uword l = 0; l < x.n_rows; l++)
  36.         if (x(k, j) == x(l, i))
  37.           count++;
  38.     if (count >= 5)
  39.       out(i)++;
  40.   }
  41. }
  42. return out;
  43. }')
  44. st = proc.time()
  45. count2 = count_cpp(dat) - 1
  46. proc.time() - st
  47. #   user  system elapsed
  48. #  7.28    0.01    7.30
  49.  
  50. # method 3 by Rcpp and RcppParallel
  51. library(Rcpp)
  52. library(RcppArmadillo)
  53. library(RcppParallel)
  54. sourceCpp(code = '
  55. // [[Rcpp::depends(RcppArmadillo, RcppParallel)]]
  56. // [[Rcpp::plugins("cpp11")]]
  57. #include <RcppArmadillo.h>
  58. #include <RcppParallel.h>
  59. using namespace Rcpp;
  60. using namespace arma;
  61. using namespace RcppParallel;
  62.  
  63. struct CountWorker: public Worker {
  64.  Mat<int>& tableMat;
  65.  Mat<int>& data;
  66.  Col<int>& output;
  67.  CountWorker(Mat<int>& tableMat, Mat<int>& data, Col<int>& output) :
  68.     tableMat(tableMat), data(data), output(output) {}
  69.  void operator()(std::size_t begin, std::size_t end) {
  70.    for (std::size_t i = begin; i < end; i++)
  71.    {
  72.      uvec tmp = find(sum(tableMat.cols(conv_to<uvec>::from(data.col(i))-1), 1) >= 5);
  73.      output(i) = tmp.n_elem;
  74.    }
  75.  }
  76. };
  77.  
  78. // [[Rcpp::export]]
  79. Col<int> count_cpp(IntegerMatrix xr, IntegerVector tableVecr) {
  80.  Mat<int> x(xr.begin(), xr.nrow(), xr.ncol(), false);
  81.  Col<int> tableVec(tableVecr.begin(), tableVecr.size(), false);
  82.  Mat<int> tableMat = zeros< Mat<int> >(x.n_cols, tableVec.n_elem);
  83.  Col<int> output = zeros< Col<int> >(x.n_cols);
  84.  for (uword i = 0; i < x.n_cols; i++)
  85.    for (uword j = 0; j < x.n_rows; j++)
  86.      tableMat(i, x(j, i)-1)++;
  87.  CountWorker countWorker(tableMat, x, output);
  88.  parallelFor(0, x.n_cols, countWorker);
  89.  return output;
  90. }')
  91. st = proc.time()
  92. tmp = unique(sort(dat))
  93. count3 = count_cpp(dat, tmp) - 1
  94. proc.time() - st
  95. #   user  system elapsed
  96. #   1.16    0.06    0.28
  97.  
  98. # method 4 modified version of the codes written by Edster
  99. st = proc.time()
  100. Y = unique(sort(dat))
  101. Z = matrix(0, ncol(dat), length(Y))
  102. count4 = vector('numeric', ncol(dat))
  103. for(i in 1:ncol(dat))
  104.   for(j in 1:10)
  105.     Z[i, dat[j, i]] = 1
  106. for(i in 1:ncol(dat))
  107.   count4[i] = sum(rowSums(Z[, dat[,i]]) >= 5) - 1
  108. proc.time() - st
  109. #   user  system elapsed
  110. #   7.10    1.31    8.47
  111.  
  112. all.equal(count, as.vector(count2))
  113. all.equal(count, as.vector(count3))
  114. all.equal(count, as.vector(count4))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement