celestialgod

Multi-objective Pareto front (C++ part)

Nov 7th, 2016
225
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.22 KB | None | 0 0
  1. // [[Rcpp::depends(RcppArmadillo, RcppParallel)]]
  2. #include <RcppArmadillo.h>
  3. using namespace arma;
  4. #include <RcppParallel.h>
  5. using namespace RcppParallel;
  6.  
  7. void chk_mat(const mat& x, const std::string& varName, const std::string& type){
  8.   if (!is_finite(x))
  9.     Rcpp::stop(varName + " must be numerical.\n");
  10. }
  11.  
  12. int compare_vec(const rowvec& mat_row, const rowvec& pivot_row){
  13.   int v = 0;
  14.   uvec v1 = find(mat_row < pivot_row, 1),
  15.   v2 = find(mat_row > pivot_row, 1);
  16.   if (v1.is_empty() && !v2.is_empty())
  17.     v = -1;
  18.   if (!v1.is_empty() && v2.is_empty())
  19.     v = 1;
  20.   if (!v1.is_empty() && !v2.is_empty())
  21.     v = (v1(0) < v2(0)) ? 1 : -1;
  22.   return v;
  23. }
  24.  
  25. // quick sort algorithm to sort rows in a matrix
  26. void sortrows(mat& M, uvec& idx, const int& left, const int& right){
  27.   if (left < right) {
  28.     int i = left, j = right;
  29.     // find a middle location
  30.     uword mid_loc = (uword) (left+right)/2, pivot_loc = mid_loc;
  31.     // use median of central 5 rows to get a pivot row to split rows
  32.     if (right - left > 5) {
  33.       uvec sortIndex = stable_sort_index(M.col(0).subvec(mid_loc-2, mid_loc+2));
  34.       pivot_loc = as_scalar(find(sortIndex == 2)) + mid_loc - 1;
  35.     }
  36.     // get pivot row
  37.     rowvec pivot_row = M.row(pivot_loc);
  38.     // use pivot row to split matrix into two parts
  39.     while (i <= j) {
  40.       // move i to right if left row > pivot low
  41.       while (compare_vec(M.row( (uword) i), pivot_row) == 1)
  42.         ++i;
  43.       // move j to left if right row < pivot low
  44.       while (compare_vec(M.row( (uword) j), pivot_row) == -1)
  45.         --j;
  46.       // exchange two rows
  47.       if (i <= j) {
  48.         M.swap_rows((uword) i, (uword) j);
  49.         idx.swap_rows((uword) i, (uword) j);
  50.         ++i;
  51.         --j;
  52.       }
  53.     }
  54.     // sort right part
  55.     if (j > 0)
  56.       sortrows(M, idx, left, j);
  57.     // sort left part
  58.     if (i < (int) M.n_rows - 1)
  59.       sortrows(M, idx, i, right);
  60.   }
  61. }
  62.  
  63. struct parallelComp: public Worker {
  64.   const mat& x;
  65.   rowvec y;
  66.   uvec& out;
  67.   parallelComp(const mat& x, rowvec y, uvec& out):
  68.     x(x), y(y), out(out) {}
  69.   void operator()(std::size_t begin, std::size_t end)
  70.   {
  71.     for (uword j = begin; j < end; j++)
  72.       out(j) = any(x.row(j) >= y);
  73.   }
  74. };
  75.  
  76. // [[Rcpp::export]]
  77. Rcpp::IntegerVector MPF3(arma::mat x, bool parallel = true){
  78.   chk_mat(x, "x", "double");
  79.   uvec order_i = linspace<uvec>(1, x.n_rows, x.n_rows);
  80.   sortrows(x, order_i, 0, x.n_rows - 1);
  81.  
  82.   uvec uniIdx = join_cols(ones<uvec>(1), any(x.rows(0, x.n_rows-2) != x.rows(1, x.n_rows-1), 1));
  83.   x = x.rows(find(uniIdx));
  84.   order_i = order_i(find(uniIdx));
  85.  
  86.   uword i = 0, tmp;
  87.   uvec idx, tmpIdx;
  88.   while (true) {
  89.     tmp = order_i(i);
  90.     idx.zeros(x.n_rows);
  91.     if (parallel) {
  92.       parallelComp paraComp(x, x.row(i), idx);
  93.       parallelFor(0, x.n_rows, paraComp);
  94.     } else {
  95.       for (uword j = 0; j < x.n_rows; j++)
  96.         idx(j) = any(x.row(j) >= x.row(i));
  97.     }
  98.     tmpIdx = find(idx == 1);
  99.     order_i = order_i.elem(tmpIdx);
  100.     x = x.rows(tmpIdx);
  101.     i = as_scalar(find(order_i == tmp, 1, "first")) + 1;
  102.     if (i >= order_i.n_elem)
  103.       break;
  104.   }
  105.   Rcpp::IntegerVector outVec = Rcpp::wrap(order_i);
  106.   outVec.attr("dim") = R_NilValue;
  107.   return outVec;
  108. }
Advertisement
Add Comment
Please, Sign In to add comment