celestialgod

LOOCV with Rcpp (RcppLOOCV.cpp)

Dec 31st, 2016
181
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 6.93 KB | None | 0 0
  1. // [[Rcpp::depends(RcppArmadillo, RcppEigen, RcppParallel)]]
  2. // [[Rcpp::plugins(openmp)]]
  3. #include <RcppArmadillo.h>
  4. #include <RcppEigen.h>
  5. #include <omp.h>
  6. #include <RcppParallel.h>
  7. using arma::mat;
  8. using arma::vec;
  9. using arma::uvec;
  10. using arma::uword;
  11. using Eigen::Map;
  12. using Eigen::MatrixXd;
  13. using Eigen::VectorXd;
  14.  
  15. struct arma_MSE_Compute: public RcppParallel::Worker {
  16.   const mat& X;
  17.   const vec& y;
  18.   const uvec& index;
  19.   double mse;
  20.  
  21.   arma_MSE_Compute(const mat& X, const vec& y, const uvec& index):
  22.     X(X), y(y), index(index), mse(0.0) {}
  23.  
  24.   arma_MSE_Compute(const arma_MSE_Compute& arma_MSE_worker, RcppParallel::Split):
  25.     X(arma_MSE_worker.X), y(arma_MSE_worker.y), index(arma_MSE_worker.index), mse(0.0) {}
  26.  
  27.   void operator()(std::size_t begin, std::size_t end) {
  28.     for (uword i = begin; i < end; ++i) {
  29.       uvec idx = arma::find(index != i);
  30.       mse += pow(y(i) - arma::dot(X.row(i), arma::solve(X.rows(idx), y.elem(idx))), 2.0);
  31.     }
  32.   }
  33.  
  34.   void join(const arma_MSE_Compute& rhs) {
  35.     mse += rhs.mse;
  36.   }
  37. };
  38.  
  39. // [[Rcpp::export]]
  40. double arma_fastLOOCV1(const arma::vec& y, const arma::mat& X) {
  41.   mat X_with_ones = arma::join_rows(arma::ones<vec>(X.n_rows), X);
  42.   uvec index = arma::linspace<uvec>(0, y.n_elem - 1, y.n_elem);
  43.  
  44.   arma_MSE_Compute mseResults(X_with_ones, y, index);
  45.   RcppParallel::parallelReduce(0, y.n_elem, mseResults);
  46.   return mseResults.mse / y.n_elem;
  47. }
  48.  
  49. // [[Rcpp::export]]
  50. double arma_fastLOOCV2(const arma::vec& y, const arma::mat& X) {
  51.   mat X_with_ones = arma::join_rows(arma::ones<vec>(X.n_rows), X);
  52.   uvec index = arma::linspace<uvec>(0, y.n_elem - 1, y.n_elem);
  53.   vec mse = arma::zeros<vec>(y.n_elem);
  54.  
  55.   uword i = 0;
  56.   #pragma omp parallel for private(i)
  57.   for (i = 0; i < y.n_elem; ++i) {
  58.     uvec idx = arma::find(index != i);
  59.     mse(i) = pow(y(i) - arma::dot(X_with_ones.row(i), arma::solve(X_with_ones.rows(idx), y.elem(idx))), 2.0);
  60.   }
  61.   return mean(mse);
  62. }
  63.  
  64. struct eigen_MSE_Compute: public RcppParallel::Worker {
  65.   MatrixXd X;
  66.   VectorXd y;
  67.   double mse;
  68.  
  69.   eigen_MSE_Compute(MatrixXd X, VectorXd y):
  70.     X(X), y(y), mse(0.0) {}
  71.  
  72.   eigen_MSE_Compute(const eigen_MSE_Compute& eigen_MSE_worker, RcppParallel::Split):
  73.     X(eigen_MSE_worker.X), y(eigen_MSE_worker.y), mse(0.0) {}
  74.  
  75.   void operator()(std::size_t begin, std::size_t end) {
  76.     for (unsigned int i = begin; i < end; ++i) {
  77.       MatrixXd tmpX(X.rows() - 1, X.cols());
  78.       VectorXd tmpY(y.size() - 1);
  79.       if (i == 0) {
  80.         tmpX = X.bottomRows(X.rows() - 1);
  81.         tmpY = y.tail(y.size() - 1);
  82.       } else if (i == X.rows() - 1) {
  83.         tmpX = X.topRows(X.rows() - 1);
  84.         tmpY = y.head(y.size() - 1);
  85.       } else {
  86.         tmpX << X.topRows(i),
  87.                 X.bottomRows(X.rows() - i - 1);
  88.         tmpY << y.head(i),
  89.                 y.tail(y.size() - i - 1);
  90.       }
  91.       mse += pow(y(i) - (tmpX.colPivHouseholderQr().solve(tmpY)).dot(X.row(i)), 2.0);
  92.     }
  93.   }
  94.  
  95.   void join(const eigen_MSE_Compute& rhs) {
  96.     mse += rhs.mse;
  97.   }
  98. };
  99.  
  100. // [[Rcpp::export]]
  101. double eigen_fastLOOCV1(const Eigen::Map<VectorXd>& y,
  102.                         const Eigen::Map<MatrixXd>& X) {
  103.   MatrixXd X_with_ones(X.rows(), X.cols() + 1);
  104.   X_with_ones << MatrixXd::Ones(X.rows(), 1), X;
  105.  
  106.   eigen_MSE_Compute mseResults(X_with_ones, y);
  107.   RcppParallel::parallelReduce(0, y.size(), mseResults);
  108.   return mseResults.mse / y.size();
  109. }
  110.  
  111. struct eigen_MSE_Compute2: public RcppParallel::Worker {
  112.   const MatrixXd& X;
  113.   const VectorXd& y;
  114.   VectorXd& mse;
  115.  
  116.   eigen_MSE_Compute2(const MatrixXd& X, const VectorXd& y, VectorXd& mse):
  117.     X(X), y(y), mse(mse) {}
  118.  
  119.   void operator()(std::size_t begin, std::size_t end) {
  120.     for (unsigned int i = begin; i < end; ++i) {
  121.       MatrixXd tmpX(X.rows() - 1, X.cols());
  122.       VectorXd tmpY(y.size() - 1);
  123.       if (i == 0) {
  124.         tmpX = X.bottomRows(X.rows() - 1);
  125.         tmpY = y.tail(y.size() - 1);
  126.       } else if (i == X.rows() - 1) {
  127.         tmpX = X.topRows(X.rows() - 1);
  128.         tmpY = y.head(y.size() - 1);
  129.       } else {
  130.         tmpX << X.topRows(i),
  131.                 X.bottomRows(X.rows() - i - 1);
  132.         tmpY << y.head(i),
  133.                 y.tail(y.size() - i - 1);
  134.       }
  135.       mse(i) = pow(y(i) - (tmpX.colPivHouseholderQr().solve(tmpY)).dot(X.row(i)), 2.0);
  136.     }
  137.   }
  138. };
  139.  
  140. // [[Rcpp::export]]
  141. double eigen_fastLOOCV2(Rcpp::NumericVector yin,
  142.                         const Eigen::Map<MatrixXd>& X) {
  143.   // Eigen::Map<VectorXd>& object in RcppParall::Worker would cause crash
  144.   // but we can input Rcpp::NumericVector, and use VectorXd::Map to convert to VectorXd
  145.   VectorXd y = VectorXd::Map(yin.begin(), yin.size());
  146.   MatrixXd X_with_ones(X.rows(), X.cols() + 1);
  147.   X_with_ones << MatrixXd::Ones(X.rows(), 1), X;
  148.   VectorXd mse = VectorXd::Zero(y.size());
  149.  
  150.  
  151.   eigen_MSE_Compute2 mseResults(X_with_ones, y, mse);
  152.   RcppParallel::parallelFor(0, y.size(), mseResults);
  153.   return mse.mean();
  154. }
  155.  
  156. // [[Rcpp::export]]
  157. double eigen_fastLOOCV3(const Eigen::Map<VectorXd>& y,
  158.                         const Eigen::Map<MatrixXd>& X) {
  159.   MatrixXd X_with_ones(X.rows(), X.cols() + 1);
  160.   X_with_ones << MatrixXd::Ones(X.rows(), 1), X;
  161.  
  162.   VectorXd mse = VectorXd::Zero(y.size());
  163.   unsigned int i = 0;
  164.  
  165.   #pragma omp parallel for private(i)
  166.   for (i = 0; i < X.rows(); ++i) {
  167.     MatrixXd tmpX(X.rows() - 1, X_with_ones.cols());
  168.     VectorXd tmpY(y.size() - 1);
  169.     if (i == 0) {
  170.       tmpX = X_with_ones.bottomRows(X_with_ones.rows() - 1);
  171.       tmpY = y.tail(y.size() - 1);
  172.     } else if (i == X_with_ones.rows() - 1) {
  173.       tmpX = X_with_ones.topRows(X_with_ones.rows() - 1);
  174.       tmpY = y.head(y.size() - 1);
  175.     } else {
  176.       tmpX << X_with_ones.topRows(i),
  177.               X_with_ones.bottomRows(X_with_ones.rows() - i - 1);
  178.       tmpY << y.head(i),
  179.               y.tail(y.size() - i - 1);
  180.     }
  181.     mse(i) = pow(y(i) - (tmpX.colPivHouseholderQr().solve(tmpY)).dot(X_with_ones.row(i)), 2.0);
  182.   }
  183.   return mse.mean();
  184. }
  185.  
  186. // [[Rcpp::export]]
  187. double eigen_fastLOOCV4(const Eigen::Map<VectorXd>& y,
  188.                         const Eigen::Map<MatrixXd>& X) {
  189.   MatrixXd X_with_ones(X.rows(), X.cols() + 1);
  190.   X_with_ones << MatrixXd::Ones(X.rows(), 1), X;
  191.  
  192.   double mse = 0.0;
  193.   for (unsigned int i = 0; i < X.rows(); ++i) {
  194.     MatrixXd tmpX(X.rows() - 1, X_with_ones.cols());
  195.     VectorXd tmpY(y.size() - 1);
  196.     if (i == 0) {
  197.       tmpX = X_with_ones.bottomRows(X_with_ones.rows() - 1);
  198.       tmpY = y.tail(y.size() - 1);
  199.     } else if (i == X_with_ones.rows() - 1) {
  200.       tmpX = X_with_ones.topRows(X_with_ones.rows() - 1);
  201.       tmpY = y.head(y.size() - 1);
  202.     } else {
  203.       tmpX << X_with_ones.topRows(i),
  204.               X_with_ones.bottomRows(X_with_ones.rows() - i - 1);
  205.       tmpY << y.head(i),
  206.               y.tail(y.size() - i - 1);
  207.     }
  208.     mse += pow(y(i) - (tmpX.colPivHouseholderQr().solve(tmpY)).dot(X_with_ones.row(i)), 2.0);
  209.   }
  210.   return mse / y.size();
  211. }
Advertisement
Add Comment
Please, Sign In to add comment