celestialgod

merge with function aggregation summary

Oct 13th, 2015
291
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
R 11.80 KB | None | 0 0
  1. library(data.table)
  2. library(plyr)
  3. library(dplyr)
  4. library(tidyr)
  5. library(purrr)
  6. library(magrittr)
  7. library(Rcpp)
  8. library(RcppArmadillo)
  9.  
  10. run_dplyr_method = TRUE
  11. run_plyr_method = TRUE
  12.  
  13. w <- matrix(sample(2e6), 1e5) %>% data.table() %>%
  14.      setnames(1:20,sample(LETTERS,20)) %>% .[,SP:=seq_len(nrow(.))]
  15. x <- matrix(sample(1e6), 1e5) %>% data.table() %>%
  16.      setnames(1:10,sample(LETTERS,10)) %>% .[,SP:=seq_len(nrow(.))]
  17. y <- matrix(sample(1e5), 1e4) %>% data.table() %>%
  18.      setnames(1:10,sample(LETTERS,10)) %>% .[,SP:=seq_len(nrow(.))]
  19. z <- matrix(sample(4e5), 2e4) %>% data.table() %>%
  20.      setnames(1:20,sample(LETTERS,20)) %>% .[,SP:=seq_len(nrow(.))]
  21.  
  22. cppFunction('
  23. NumericVector addv(NumericVector x, NumericVector y) {
  24.  NumericVector out(x.size());
  25.  NumericVector::iterator x_it,y_it,out_it;
  26.  for (x_it = x.begin(), y_it=y.begin(), out_it = out.begin();
  27.       x_it != x.end(); ++x_it, ++y_it, ++out_it) {
  28.    if (ISNA(*x_it)) {
  29.      *out_it = *y_it;
  30.    } else if (ISNA(*y_it)) {
  31.      *out_it = *x_it;
  32.    } else {
  33.      *out_it = *x_it + *y_it;
  34.    }
  35.  }
  36.  return out;
  37. }')
  38.  
  39. outer_join2 <- function(df1,df2,byNames) {
  40.   tt = intersect(colnames(df1)[-match(byNames,colnames(df1))],
  41.                colnames(df2)[-match(byNames,colnames(df2))])
  42.   if (length(tt) > 0)
  43.   {
  44.     df <- merge(df2, df1 %>% select_(.dots = paste0('-', tt)), by=byNames, all=TRUE)
  45.     dt <- merge(df2 %>% select_(.dots = paste0('-', tt)),
  46.         df1 %>% select_(.dots = c(byNames, tt)),by=byNames,all=T) %>%
  47.           select_(.dots = tt)
  48.     for (j in colnames(dt)) {set(df,j=j,value=addv(df[[j]],dt[[j]]))}
  49.   } else
  50.   {
  51.     df = merge(df2, df1, by=byNames, all=TRUE)
  52.   }
  53.   return (df)
  54. }
  55.  
  56. t = proc.time()
  57. out1 = Reduce(function(df1, df2) outer_join2(df1, df2, byNames="SP"), list(w,x,y,z)) %>% tbl_dt(FALSE)
  58. proc.time() - t
  59. #   user  system elapsed
  60. #   0.34    0.07    0.40
  61.  
  62. sourceCpp(code = '
  63. #include <Rcpp.h>
  64. #include <string>
  65. #include <vector>
  66. using namespace Rcpp;
  67. using namespace std;
  68.  
  69. // [[Rcpp::export]]
  70. List aggregate_merge_cpp(DataFrame df1, DataFrame df2,
  71.  vector<string> names_merge, Function aggregate_f){
  72.  List outputList(names_merge.size());
  73.  for (size_t i = 0; i < names_merge.size(); i++)
  74.  {
  75.    switch ( TYPEOF(df1[names_merge[i]]) ) {
  76.      case REALSXP: {
  77.        NumericVector tmp1 = as<NumericVector>(df1[names_merge[i]]),
  78.          tmp2 = as<NumericVector>(df2[names_merge[i]]);
  79.        outputList[i] = aggregate_f(tmp1, tmp2);
  80.        break;
  81.      }
  82.      case INTSXP: {
  83.        IntegerVector tmp1 = as<IntegerVector>(df1[names_merge[i]]),
  84.          tmp2 = as<IntegerVector>(df2[names_merge[i]]);
  85.        outputList[i] = aggregate_f(tmp1, tmp2);
  86.        break;
  87.      }
  88.      default:
  89.        stop("unsupported data type");
  90.    }
  91.  }
  92.  outputList.attr("names") = names_merge;
  93.  return outputList;
  94. }');
  95.  
  96. aggregate_merge <- function(x, y, byNames, aggregate_f = magrittr:::add){
  97.   if (!"data.table" %in% class(x))
  98.     x %<>% tbl_dt(FALSE)
  99.   if (!"data.table" %in% class(y))
  100.     y %<>% tbl_dt(FALSE)
  101.  
  102.   aggregate_names = setdiff(intersect(names(x), names(y)), byNames)
  103.   if ( length(aggregate_names) > 0)
  104.   {
  105.     x_index = x %>% select_(.dots = byNames)
  106.     y_index = y %>% select_(.dots = byNames)
  107.     indecies = bind_rows(x_index, y_index)
  108.     combine_rows = which(duplicated(indecies))
  109.     indecies %<>% filter(1:nrow(.) %in% combine_rows)
  110.  
  111.     if (length(byNames) >= 2)
  112.     {
  113.       stop("This part is undone for merging byNames whose length is greater than 2.")
  114.     } else
  115.     {
  116.       x_index = match(indecies[[byNames]], x_index[[byNames]])
  117.       y_index = match(indecies[[byNames]], y_index[[byNames]])
  118.     }
  119.     aggregate_dt = aggregate_merge_cpp(
  120.       x %>% select_(.dots = aggregate_names) %>% filter(1:nrow(.) %in% x_index) %>% replace(is.na(.), as.integer(0)),
  121.       y %>% select_(.dots = aggregate_names) %>% filter(1:nrow(.) %in% y_index) %>% replace(is.na(.), as.integer(0)),
  122.       aggregate_names, aggregate_f) %>% as_data_frame %>%
  123.       bind_cols(x %>% select_(.dots = byNames) %>% filter(1:nrow(.) %in% x_index)) %>%
  124.       bind_rows(
  125.         x %>% select_(.dots = c(byNames, aggregate_names)) %>% filter(!(1:nrow(.) %in% x_index)),
  126.         y %>% select_(.dots = c(byNames, aggregate_names)) %>% filter(!(1:nrow(.) %in% y_index))) %>%
  127.       tbl_dt(FALSE) %>% arrange_(.dots = byNames) %>% select_(.dots = paste0("-", byNames))
  128.     out_dt = merge(x %>% select_(.dots = setdiff(names(.), aggregate_names)),
  129.           y %>% select_(.dots = setdiff(names(.), aggregate_names)),
  130.           by = byNames, all = TRUE) %>% arrange_(.dots = byNames) %>%
  131.        bind_cols(aggregate_dt) %>% tbl_dt(FALSE)
  132.   }
  133.   else
  134.     out_dt = merge(x, y, by = byNames, all = TRUE)
  135.   out_dt
  136. }
  137. t = proc.time()
  138. out2 = Reduce(function(df1, df2) aggregate_merge(df1, df2, byNames="SP"), list(w,x,y,z))
  139. proc.time() - t
  140. #   user  system elapsed
  141. #   0.56    0.03    0.59
  142.  
  143. if (run_dplyr_method)
  144. {
  145.   t = proc.time()
  146.   wide_table = rbind.fill(list(w, x, y, z)) %>% tbl_dt(FALSE)
  147.   sum_without_na = function(vec) ifelse(all(is.na(vec)), NA_integer_, sum(vec, na.rm = TRUE))
  148.   out3 = wide_table %>% group_by(SP) %>% summarise_each(funs(sum_without_na))
  149.   proc.time() - t
  150. }
  151. #   user  system elapsed
  152. #   5.77    0.03    5.80
  153.  
  154. t = proc.time()
  155. out4 = list(w, x, y, z) %>% llply(function(dt){
  156.   gather(dt, variable, values, -SP) %>% mutate(variable = as.character(variable))
  157.   }) %>% bind_rows %>% tbl_dt(FALSE) %>% group_by(SP, variable) %>%
  158.   summarise(values = sum(values)) %>% spread(variable, values)
  159. proc.time() - t
  160. #   user  system elapsed
  161. #   1.36    0.08    1.44
  162.  
  163. if (run_plyr_method)
  164. {
  165.   t = proc.time()
  166.   wide_table = rbind.fill(list(w, x, y, z)) %>% tbl_dt(FALSE)
  167.   out5 = ddply(wide_table, .(SP), function(mat) colSums(mat, na.rm = TRUE)) %>% tbl_dt(FALSE)
  168.   proc.time() - t
  169. #   user  system elapsed
  170. #  28.97    0.12   29.09
  171. }
  172.  
  173.  
  174. aggregate_merge <- function(..., by){
  175.   dfs = list(...)
  176.   overall_keys = dfs %>% map(~select_(., .dots = by)) %>% rbindlist %>% distinct
  177.   dfs %<>% map(~merge(., overall_keys, by = by, all = TRUE))
  178.   duplicated_cols = dfs %>% map(~names(.)) %>% do.call(c, .) %>%
  179.     .[duplicated(.)] %>% setdiff(by)
  180.   tmp = llply(dfs, function(dt) llply(1:length(duplicated_cols), function(i){
  181.     dt[[duplicated_cols[i]]]
  182.   })) %>% zip_n %>% map(~do.call(cbind, .))
  183.   tmp2 = tmp %>% map(~which(rowMeans(is.na(.)) == 1))
  184.   tmp %<>% map(~rowSums(., na.rm = TRUE)) %>%
  185.     set_names(duplicated_cols) %>% as_data_frame %>% tbl_dt(FALSE) %>%
  186.     bind_cols(overall_keys) %>% tbl_dt(FALSE)
  187.   for (i in 1:length(tmp2))
  188.   {
  189.     if (length(tmp2[[i]]) > 0)
  190.       set(tmp, i = tmp2[[i]], j = i, value = NA)
  191.   }
  192.  
  193.   bind_dfs_length = llply(dfs, function(dt) setdiff(names(dt), c(by, duplicated_cols))) %>% map_int(~length(.))
  194.   dfs[bind_dfs_length > 0] %>% map(~select_(., .dots = setdiff(names(.), c(by, duplicated_cols)))) %>%
  195.     bind_cols %>% tbl_dt(FALSE) %>% bind_cols(overall_keys) %>% tbl_dt(FALSE) %>%
  196.     merge(tmp, by = by)
  197. }
  198.  
  199. t = proc.time()
  200. out6 = aggregate_merge(w, x, y, z, by = "SP")
  201. proc.time() - t
  202. #   user  system elapsed
  203. #   0.36    0.02    0.38
  204.  
  205. sourceCpp(code = '
  206. #include <Rcpp.h>
  207. #include <string>
  208. #include <vector>
  209. using namespace Rcpp;
  210. using namespace std;
  211.  
  212. NumericVector add_narm(NumericVector x, NumericVector y) {
  213. NumericVector out(x.size());
  214. NumericVector::iterator x_it,y_it,out_it;
  215. for (x_it = x.begin(), y_it=y.begin(), out_it = out.begin();
  216.      x_it != x.end(); ++x_it, ++y_it, ++out_it) {
  217.   if (ISNA(*x_it)) {
  218.     *out_it = *y_it;
  219.   } else if (ISNA(*y_it)) {
  220.     *out_it = *x_it;
  221.   } else {
  222.     *out_it = *x_it + *y_it;
  223.   }
  224. }
  225. return out;
  226. }
  227.  
  228. // [[Rcpp::export]]
  229. DataFrame merge_all_cpp(List dfs){
  230. DataFrame out(dfs[0]);
  231. vector<string> df_names = out.names();
  232. for (int i = 1; i < dfs.length(); i++)
  233. {
  234.   DataFrame tmp(dfs[i]);
  235.   vector<string> tmp_names = tmp.names();
  236.   for (int j = 0; j < tmp.length(); j++)
  237.   {
  238.     if (find(df_names.begin(), df_names.end(), tmp_names[j]) != df_names.end())
  239.     {
  240.       out[tmp_names[j]] = add_narm(
  241.         as<NumericVector>(out[tmp_names[j]]),
  242.         as<NumericVector>(tmp[tmp_names[j]]));
  243.     }
  244.     else
  245.     {
  246.       out.push_back(tmp[tmp_names[j]], tmp_names[j]);
  247.       df_names.push_back(tmp_names[j]);
  248.     }
  249.   }
  250. }
  251. return out;
  252. }');
  253.  
  254. t = proc.time()
  255. dfs = list(w,x,y,z)
  256. by = "SP"
  257. overall_keys = list(w,x,y,z) %>% map(~select_(., .dots = by)) %>% rbindlist %>% distinct
  258. dfs %<>% map(~merge(., overall_keys, by = by, all = TRUE))
  259. ## not use purrr
  260. # overall_keys = dfs %>% lapply(function(dt) select_(x, .dots = by)) %>% rbindlist %>% distinct
  261. # dfs %<>% lapply(function(dt) merge(x, overall_keys, by = by, all = TRUE))
  262. f = function(var) var / length(dfs)
  263. out7 = merge_all_cpp(dfs) %>% tbl_dt(FALSE) %>% mutate_each_(funs(f), by)
  264. proc.time() - t
  265. #   user  system elapsed
  266. #   0.35    0.00    0.35
  267.  
  268. sourceCpp(code = '
  269. // [[Rcpp::depends(RcppArmadillo)]]
  270. #include <RcppArmadillo.h>
  271. #include <string>
  272. #include <vector>
  273. #include <algorithm>
  274. using namespace Rcpp;
  275. using namespace std;
  276. using namespace arma;
  277.  
  278. // [[Rcpp::export]]
  279. NumericVector add_narm(SEXP xs, SEXP ys){
  280.  NumericVector xr(xs);
  281.  NumericVector yr(ys);
  282.  colvec x(xr.begin(), xr.size(), false);
  283.  colvec y(yr.begin(), yr.size(), false);
  284.  uvec loc_na_x = find_nonfinite(x);
  285.  uvec loc_na_y = find_nonfinite(y);
  286.  colvec z = x + y;
  287.  vector<uword> loc_na_std;
  288.  set_intersection(loc_na_x.begin(), loc_na_x.end(),
  289.    loc_na_y.begin(), loc_na_y.end(),
  290.    back_inserter(loc_na_std));
  291.  uvec loc_na = conv_to<uvec>::from(loc_na_std);
  292.  x.elem(loc_na_x).zeros();
  293.  y.elem(loc_na_y).zeros();
  294.  z = x + y;
  295.  z.elem(loc_na).fill(NA_REAL);
  296.  return wrap(z);
  297. }
  298.  
  299. // [[Rcpp::export]]
  300. DataFrame merge_all_cpp_v2(List dfs){
  301.  DataFrame out(dfs[0]);
  302.  vector<string> df_names = out.names();
  303.  for (int i = 1; i < dfs.length(); i++)
  304.  {
  305.    DataFrame tmp(dfs[i]);
  306.    vector<string> tmp_names = tmp.names();
  307.    for (int j = 0; j < tmp.length(); j++)
  308.    {
  309.      if (find(df_names.begin(), df_names.end(), tmp_names[j]) != df_names.end())
  310.        out[tmp_names[j]] = add_narm(out[tmp_names[j]], tmp[tmp_names[j]]);
  311.      else
  312.      {
  313.        out.push_back(tmp[tmp_names[j]], tmp_names[j]);
  314.        df_names.push_back(tmp_names[j]);
  315.      }
  316.    }
  317.  }
  318.  return out;
  319. }')
  320.  
  321. t = proc.time()
  322. dfs = list(w,x,y,z)
  323. by = "SP"
  324. overall_keys = dfs %>% map(~select_(., .dots = by)) %>% rbindlist %>% distinct
  325. dfs %<>% map(~merge(., overall_keys, by = by, all = TRUE))
  326. ## not use purrr
  327. # overall_keys = dfs %>% lapply(function(dt) select_(dt, .dots = by)) %>% rbindlist %>% distinct
  328. # dfs %<>% lapply(function(dt) merge(dt, overall_keys, by = by, all = TRUE))
  329. f = function(var) var / length(dfs)
  330. out8 = merge_all_cpp_v2(dfs) %>% tbl_dt(FALSE) %>% mutate_each_(funs(f), by)
  331. proc.time() - t
  332. #   user  system elapsed
  333. #   0.11    0.00    0.11
  334.  
  335. out1 %<>% select_(.dots = names(out4)) %>% arrange(SP) %>% mutate_each(funs(as.integer))
  336. out2 %<>% select_(.dots = names(out4)) %>% arrange(SP)
  337. if (run_dplyr_method)
  338.   out3 %<>% select_(.dots = names(out4)) %>% arrange(SP)
  339. out4 %<>% arrange(SP)
  340. if (run_plyr_method)
  341.   out5 %<>% select_(.dots = names(out4)) %>% arrange(SP) %>% mutate_each_(funs(f), by) %>% mutate_each(funs(as.integer))
  342. out6 %<>% select_(.dots = names(out4)) %>% arrange(SP) %>% mutate_each(funs(as.integer))
  343. out7 %<>% select_(.dots = names(out4)) %>%
  344.   arrange_(.dots = by)  %>% mutate_each(funs(as.integer))
  345. out8 %<>% select_(.dots = names(out4)) %>%
  346.   arrange_(.dots = by)  %>% mutate_each(funs(as.integer))
  347.  
  348. all.equal(out4, out1) # TRUE
  349. all.equal(out4, out2) # TRUE
  350. if (run_dplyr_method)
  351.   all.equal(out4, out3) # TRUE
  352. if (run_plyr_method)
  353.   all.equal(out4, out5) # it has a error for NA + NA = 0
  354. all.equal(out4, out6) # TRUE
  355. all.equal(out4, out7) # TRUE
  356. all.equal(out4, out8) # TRUE
Advertisement
Add Comment
Please, Sign In to add comment