celestialgod

Gaussian (RBF) kernel matrix computed with RcppArmadillo, OpenMP, and RcppParallel

Aug 6th, 2015
459
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
R 4.03 KB | None | 0 0
  1. library(kernlab)
  2. library(Rcpp)
  3. library(RcppArmadillo)
  4. library(RcppParallel)
  5. sourceCpp(code = '
  6. // [[Rcpp::depends(RcppArmadillo)]]
  7. #include <RcppArmadillo.h>
  8. using namespace Rcpp;
  9. using namespace arma;
  10.  
  11. // [[Rcpp::export]]
  12. NumericMatrix kernelMatrix_Arma(NumericMatrix Xr, NumericMatrix Centerr, double sigma) {
  13.  uword n = Xr.nrow(), b = Centerr.nrow(), row_index, col_index;
  14.  mat X(Xr.begin(), n, Xr.ncol(), false),
  15.      Center(Centerr.begin(), b, Centerr.ncol(), false),
  16.      KerX(X*Center.t());
  17.  colvec X_sq = sum(square(X), 1) / 2;
  18.  rowvec Center_sq = (sum(square(Center), 1)).t() / 2;
  19.  KerX.each_row() -= Center_sq;
  20.  KerX.each_col() -= X_sq;
  21.  KerX *= 1 / (sigma * sigma);
  22.  KerX = exp(KerX);
  23.  return wrap(KerX);
  24. }')
  25.  
  26. sourceCpp(code = '
  27. // [[Rcpp::depends(RcppArmadillo)]]
  28. #include <RcppArmadillo.h>
  29. #include <omp.h>
  30. // [[Rcpp::plugins(openmp)]]
  31. using namespace Rcpp;
  32. using namespace arma;
  33.  
  34. // [[Rcpp::export]]
  35. NumericMatrix kernelMatrix_openmp(NumericMatrix Xr, NumericMatrix Centerr, double sigma) {
  36.  omp_set_num_threads(omp_get_max_threads());
  37.  uword n = Xr.nrow(), b = Centerr.nrow(), row_index, col_index;
  38.  mat X(Xr.begin(), n, Xr.ncol(), false);
  39.  mat Center(Centerr.begin(), b, Centerr.ncol(), false);
  40.  mat KerX(n, b);
  41.  #pragma omp parallel private(row_index, col_index)
  42.  for (row_index = 0; row_index < n; row_index++)
  43.  {
  44.    #pragma omp for nowait
  45.    for (col_index = 0; col_index < b; col_index++)
  46.    {
  47.      KerX(row_index, col_index) = exp(sum(square(X.row(row_index)
  48.         - Center.row(col_index))) / (-2.0 * sigma * sigma));
  49.    }
  50.  }
  51.  return wrap(KerX);
  52. }')
  53.  
  54. sourceCpp(code = '
  55. // [[Rcpp::depends(RcppArmadillo, RcppParallel)]]
  56. #define ARMA_DONT_USE_CXX11
  57. #include <RcppArmadillo.h>
  58. #include <RcppParallel.h>
  59. using namespace Rcpp;
  60. using namespace arma;
  61. using namespace RcppParallel;
  62.  
  63. struct KernelCompute: public Worker {
  64.  mat& X;
  65.  mat& Center;
  66.  double sigma;
  67.  mat& output;
  68.  KernelCompute(mat& X, mat& Center, double sigma, mat& output) : X(X), Center(Center), sigma(sigma), output(output) {}
  69.  void operator()(std::size_t begin, std::size_t end) {
  70.    for (uword row_index = begin; row_index < end; row_index++)
  71.    {
  72.      for (uword col_index = 0; col_index < Center.n_rows; col_index++)
  73.        output(row_index, col_index) = exp(sum(square(X.row(row_index)
  74.          - Center.row(col_index))) / (-2.0 * sigma * sigma));
  75.    }
  76.  }
  77. };
  78.  
  79. // [[Rcpp::export]]
  80. NumericMatrix kernelMatrix_tbb(NumericMatrix Xr, NumericMatrix Centerr, double sigma) {
  81.  uword n = Xr.nrow(), b = Centerr.nrow();
  82.  mat X(Xr.begin(), n, Xr.ncol(), false),
  83.      Center(Centerr.begin(), b, Centerr.ncol(), false), KerX(n, b);
  84.  KernelCompute kernelCompute(X, Center, sigma, KerX);
  85.  parallelFor(0, X.n_rows, kernelCompute);
  86.  return wrap(KerX);
  87. }')
  88.  
  89. N = 3000
  90. p = 100
  91. b = 500
  92. X = matrix(rnorm(N*p), ncol = p)
  93. center = X[sample(1:N, b),]
  94. sigma = 50
  95. kernel_X = kernelMatrix(rbfdot(sigma=1/(2*sigma^2)), X, center)
  96. kernel_X_arma = kernelMatrix_Arma(X, center, sigma)
  97. kernel_X_openmp = kernelMatrix_openmp(X, center, sigma)
  98. kernel_X_tbb = kernelMatrix_tbb(X, center, sigma)
  99. ## test
  100. all.equal(kernel_X@.Data, kernel_X_arma)
  101. all.equal(kernel_X@.Data, kernel_X_openmp)
  102. all.equal(kernel_X@.Data, kernel_X_tbb)
  103. # TRUE
  104.  
  105. library(rbenchmark)
  106. benchmark(
  107.   arma = kernelMatrix_Arma(X, center, sigma),
  108.   openmp = kernelMatrix_openmp(X, center, sigma),
  109.   tbb = kernelMatrix_tbb(X, center, sigma),
  110.   kernlab = kernelMatrix(rbfdot(sigma=1/(2*sigma^2)), X, center),
  111.   columns=c("test", "replications","elapsed", "relative"),
  112.   replications=20, order="relative")
  113. ## p = 100
  114. #      test replications elapsed relative
  115. # 3     tbb           20    1.51    1.000
  116. # 2  openmp           20    1.75    1.159
  117. # 1    arma           20    1.97    1.305
  118. # 4 kernlab           20    3.25    2.152
  119.  
  120. ## p = 300
  121. #      test replications elapsed relative
  122. # 1    arma           20    2.20    1.000
  123. # 4 kernlab           20    3.62    1.645
  124. # 3     tbb           20   27.52   12.509
  125. # 2  openmp           20   28.48   12.945
Advertisement
Add Comment
Please, Sign In to add comment