# merge with function aggregation with Rcpp v2
# Author: celestialgod, posted to Pastebin on Oct 13th, 2015 (R, 2.26 KB)
  1. library(data.table)
  2. library(plyr)
  3. library(dplyr)
  4. library(purrr)
  5. library(magrittr)
  6. library(Rcpp)
  7.  
  8. sourceCpp(code = '
  9. // [[Rcpp::depends(RcppArmadillo)]]
  10. #include <RcppArmadillo.h>
  11. // #include <Rcpp.h>
  12. #include <string>
  13. #include <vector>
  14. #include <algorithm>
  15. using namespace Rcpp;
  16. using namespace std;
  17. using namespace arma;
  18.  
  19. // [[Rcpp::export]]
  20. NumericVector add_narm(SEXP xs, SEXP ys){
  21.  NumericVector xr(xs);
  22.  NumericVector yr(ys);
  23.  colvec x(xr.begin(), xr.size(), false);
  24.  colvec y(yr.begin(), yr.size(), false);
  25.  uvec loc_na_x = find_nonfinite(x);
  26.  uvec loc_na_y = find_nonfinite(y);
  27.  colvec z = x + y;
  28.  vector<uword> loc_na_std;
  29.  set_intersection(loc_na_x.begin(), loc_na_x.end(),
  30.    loc_na_y.begin(), loc_na_y.end(),
  31.    back_inserter(loc_na_std));
  32.  uvec loc_na = conv_to<uvec>::from(loc_na_std);
  33.  x.elem(loc_na_x).zeros();
  34.  y.elem(loc_na_y).zeros();
  35.  z = x + y;
  36.  z.elem(loc_na).fill(NA_REAL);
  37.  return wrap(z);
  38. }
  39.  
  40. // [[Rcpp::export]]
  41. DataFrame merge_all_cpp(List dfs){
  42.  DataFrame out(dfs[0]);
  43.  vector<string> df_names = out.names();
  44.  for (int i = 1; i < dfs.length(); i++)
  45.  {
  46.    DataFrame tmp(dfs[i]);
  47.    vector<string> tmp_names = tmp.names();
  48.    for (int j = 0; j < tmp.length(); j++)
  49.    {
  50.      if (find(df_names.begin(), df_names.end(), tmp_names[j]) != df_names.end())
  51.        out[tmp_names[j]] = add_narm(out[tmp_names[j]], tmp[tmp_names[j]]);
  52.      else
  53.      {
  54.        out.push_back(tmp[tmp_names[j]], tmp_names[j]);
  55.        df_names.push_back(tmp_names[j]);
  56.      }
  57.    }
  58.  }
  59.  return out;
  60. }');
  61.  
  62. t = proc.time()
  63. dfs = list(w,x,y,z)
  64. by = "SP"
  65. overall_keys = dfs %>% map(~select_(., .dots = by)) %>% rbindlist %>% distinct
  66. dfs %<>% map(~merge(., overall_keys, by = by, all = TRUE))
  67. ## not use purrr
  68. # overall_keys = dfs %>% lapply(function(x) select_(x, .dots = by)) %>% rbindlist %>% distinct
  69. # dfs %<>% lapply(function(x) merge(x, overall_keys, by = by, all = TRUE))
  70. f = function(x) x / length(dfs)
  71. f = function(x) x / length(dfs)
  72. out8 = merge_all_cpp(dfs) %>% tbl_dt(FALSE) %>% mutate_each_(funs(f), by)
  73. proc.time() - t
  74. #   user  system elapsed
  75. #   0.04    0.03    0.12
  76. out8 %<>% select_(.dots = names(out4)) %>%
  77.   arrange_(.dots = by)  %>% mutate_each(funs(as.integer))
  78. all.equal(out4, out8) # TRUE
# (Pastebin page footer -- advertisement and comment links -- omitted.)