Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
// [[Rcpp::depends(RcppArmadillo, RcppParallel)]]
#include <RcppArmadillo.h>
#include <RcppParallel.h>

#include <algorithm>
#include <numeric>
#include <vector>

using namespace arma;
using namespace RcppParallel;
- void chk_mat(const mat& x, const std::string& varName, const std::string& type){
- if (!is_finite(x))
- Rcpp::stop(varName + " must be numerical.\n");
- }
- int compare_vec(const rowvec& mat_row, const rowvec& pivot_row){
- int v = 0;
- uvec v1 = find(mat_row < pivot_row, 1),
- v2 = find(mat_row > pivot_row, 1);
- if (v1.is_empty() && !v2.is_empty())
- v = -1;
- if (!v1.is_empty() && v2.is_empty())
- v = 1;
- if (!v1.is_empty() && !v2.is_empty())
- v = (v1(0) < v2(0)) ? 1 : -1;
- return v;
- }
- // quick sort algorithm to sort rows in a matrix
- void sortrows(mat& M, uvec& idx, const int& left, const int& right){
- if (left < right) {
- int i = left, j = right;
- // find a middle location
- uword mid_loc = (uword) (left+right)/2, pivot_loc = mid_loc;
- // use median of central 5 rows to get a pivot row to split rows
- if (right - left > 5) {
- uvec sortIndex = stable_sort_index(M.col(0).subvec(mid_loc-2, mid_loc+2));
- pivot_loc = as_scalar(find(sortIndex == 2)) + mid_loc - 1;
- }
- // get pivot row
- rowvec pivot_row = M.row(pivot_loc);
- // use pivot row to split matrix into two parts
- while (i <= j) {
- // move i to right if left row > pivot low
- while (compare_vec(M.row( (uword) i), pivot_row) == 1)
- ++i;
- // move j to left if right row < pivot low
- while (compare_vec(M.row( (uword) j), pivot_row) == -1)
- --j;
- // exchange two rows
- if (i <= j) {
- M.swap_rows((uword) i, (uword) j);
- idx.swap_rows((uword) i, (uword) j);
- ++i;
- --j;
- }
- }
- // sort right part
- if (j > 0)
- sortrows(M, idx, left, j);
- // sort left part
- if (i < (int) M.n_rows - 1)
- sortrows(M, idx, i, right);
- }
- }
- struct parallelComp: public Worker {
- const mat& x;
- rowvec y;
- uvec& out;
- parallelComp(const mat& x, rowvec y, uvec& out):
- x(x), y(y), out(out) {}
- void operator()(std::size_t begin, std::size_t end)
- {
- for (uword j = begin; j < end; j++)
- out(j) = any(x.row(j) >= y);
- }
- };
- // [[Rcpp::export]]
- Rcpp::IntegerVector MPF3(arma::mat x, bool parallel = true){
- chk_mat(x, "x", "double");
- uvec order_i = linspace<uvec>(1, x.n_rows, x.n_rows);
- sortrows(x, order_i, 0, x.n_rows - 1);
- uvec uniIdx = join_cols(ones<uvec>(1), any(x.rows(0, x.n_rows-2) != x.rows(1, x.n_rows-1), 1));
- x = x.rows(find(uniIdx));
- order_i = order_i(find(uniIdx));
- uword i = 0, tmp;
- uvec idx, tmpIdx;
- while (true) {
- tmp = order_i(i);
- idx.zeros(x.n_rows);
- if (parallel) {
- parallelComp paraComp(x, x.row(i), idx);
- parallelFor(0, x.n_rows, paraComp);
- } else {
- for (uword j = 0; j < x.n_rows; j++)
- idx(j) = any(x.row(j) >= x.row(i));
- }
- tmpIdx = find(idx == 1);
- order_i = order_i.elem(tmpIdx);
- x = x.rows(tmpIdx);
- i = as_scalar(find(order_i == tmp, 1, "first")) + 1;
- if (i >= order_i.n_elem)
- break;
- }
- Rcpp::IntegerVector outVec = Rcpp::wrap(order_i);
- outVec.attr("dim") = R_NilValue;
- return outVec;
- }
Advertisement
Add Comment
Please, Sign In to add comment