celestialgod

Merge with function aggregation

Oct 12th, 2015
219
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
R 6.04 KB | None | 0 0
  1. library(data.table)
  2. library(plyr)
  3. library(dplyr)
  4. library(tidyr)
  5. library(purrr)
  6. library(magrittr)
  7. library(Rcpp)
  8.  
  9. run_plyr_method = FALSE
  10.  
  11. x <- matrix(sample(1e6), 1e5) %>% data.table() %>%
  12.      setnames(1:10,sample(LETTERS,10)) %>% .[,SP:=seq_len(nrow(.))]
  13. y <- matrix(sample(1e5), 1e4) %>% data.table() %>%
  14.      setnames(1:10,sample(LETTERS,10)) %>% .[,SP:=seq_len(nrow(.))]
  15. z <- matrix(sample(4e5), 2e4) %>% data.table() %>%
  16.      setnames(1:20,sample(LETTERS,20)) %>% .[,SP:=seq_len(nrow(.))]
  17.  
  18. cppFunction('
  19. NumericVector addv(NumericVector x, NumericVector y) {
  20.  NumericVector out(x.size());
  21.  NumericVector::iterator x_it,y_it,out_it;
  22.  for (x_it = x.begin(), y_it=y.begin(), out_it = out.begin();
  23.       x_it != x.end(); ++x_it, ++y_it, ++out_it) {
  24.    if (ISNA(*x_it)) {
  25.      *out_it = *y_it;
  26.    } else if (ISNA(*y_it)) {
  27.      *out_it = *x_it;
  28.    } else {
  29.      *out_it = *x_it + *y_it;
  30.    }
  31.  }
  32.  return out;
  33. }')
  34.  
  35. outer_join2 <- function (df1,df2,byNames) {
  36.   tt = intersect(colnames(df1)[-match(byNames,colnames(df1))],
  37.                colnames(df2)[-match(byNames,colnames(df2))])
  38.   if (length(tt) > 0)
  39.   {
  40.     df <- merge(df2, df1 %>% select_(.dots = paste0('-', tt)), by=byNames, all=TRUE)
  41.     dt <- merge(df2 %>% select_(.dots = paste0('-', tt)),
  42.         df1 %>% select_(.dots = c(byNames, tt)),by=byNames,all=T) %>%
  43.           select_(.dots = tt)
  44.     for (j in colnames(dt)) {set(df,j=j,value=addv(df[[j]],dt[[j]]))}
  45.   } else
  46.   {
  47.     df = merge(df2, df1, by=byNames, all=TRUE)
  48.   }
  49.   return (df)
  50. }
  51.  
  52. t = proc.time()
  53. out1 = Reduce(function(x, y) outer_join2(x, y, byNames="SP"), list(x,y,z)) %>% tbl_dt(FALSE)
  54. proc.time() - t
  55. #   user  system elapsed
  56. #   0.36    0.03    0.40
  57.  
  58. sourceCpp(code = '
  59. #include <Rcpp.h>
  60. #include <string>
  61. #include <vector>
  62. using namespace Rcpp;
  63. using namespace std;
  64.  
  65. // [[Rcpp::export]]
  66. List aggregate_merge_cpp(DataFrame df1, DataFrame df2,
  67.  vector<string> names_merge, Function aggregate_f){
  68.  List outputList(names_merge.size());
  69.  for (size_t i = 0; i < names_merge.size(); i++)
  70.  {
  71.    switch ( TYPEOF(df1[names_merge[i]]) ) {
  72.      case REALSXP: {
  73.        NumericVector tmp1 = as<NumericVector>(df1[names_merge[i]]),
  74.          tmp2 = as<NumericVector>(df2[names_merge[i]]);
  75.        outputList[i] = aggregate_f(tmp1, tmp2);
  76.        break;
  77.      }
  78.      case INTSXP: {
  79.        IntegerVector tmp1 = as<IntegerVector>(df1[names_merge[i]]),
  80.          tmp2 = as<IntegerVector>(df2[names_merge[i]]);
  81.        outputList[i] = aggregate_f(tmp1, tmp2);
  82.        break;
  83.      }
  84.      default:
  85.        stop("unsupported data type");
  86.    }
  87.  }
  88.  outputList.attr("names") = names_merge;
  89.  return outputList;
  90. }');
  91.  
  92. aggregate_merge <- function(x, y, byNames, aggregate_f = magrittr:::add){
  93.   if (!"data.table" %in% class(x))
  94.     x %<>% tbl_dt(FALSE)
  95.   if (!"data.table" %in% class(y))
  96.     y %<>% tbl_dt(FALSE)
  97.  
  98.   aggregate_names = setdiff(intersect(names(x), names(y)), byNames)
  99.   if ( length(aggregate_names) > 0)
  100.   {
  101.     x_index = x %>% select_(.dots = byNames)
  102.     y_index = y %>% select_(.dots = byNames)
  103.     indecies = bind_rows(x_index, y_index)
  104.     combine_rows = which(duplicated(indecies))
  105.     indecies %<>% filter(1:nrow(.) %in% combine_rows)
  106.  
  107.     if (length(byNames) >= 2)
  108.     {
  109.       stop("This part is undone for merging byNames whose length is greater than 2.")
  110.     } else
  111.     {
  112.       x_index = match(indecies[[byNames]], x_index[[byNames]])
  113.       y_index = match(indecies[[byNames]], y_index[[byNames]])
  114.     }
  115.     aggregate_dt = aggregate_merge_cpp(
  116.       x %>% select_(.dots = aggregate_names) %>% filter(1:nrow(.) %in% x_index) %>% replace(is.na(.), as.integer(0)),
  117.       y %>% select_(.dots = aggregate_names) %>% filter(1:nrow(.) %in% y_index) %>% replace(is.na(.), as.integer(0)),
  118.       aggregate_names, aggregate_f) %>% as_data_frame %>%
  119.       bind_cols(x %>% select_(.dots = byNames) %>% filter(1:nrow(.) %in% x_index)) %>%
  120.       bind_rows(
  121.         x %>% select_(.dots = c(byNames, aggregate_names)) %>% filter(!(1:nrow(.) %in% x_index)),
  122.         y %>% select_(.dots = c(byNames, aggregate_names)) %>% filter(!(1:nrow(.) %in% y_index))) %>%
  123.       tbl_dt(FALSE) %>% arrange_(.dots = byNames) %>% select_(.dots = paste0("-", byNames))
  124.     out_dt = merge(x %>% select_(.dots = setdiff(names(.), aggregate_names)),
  125.           y %>% select_(.dots = setdiff(names(.), aggregate_names)),
  126.           by = byNames, all = TRUE) %>% arrange_(.dots = byNames) %>%
  127.        bind_cols(aggregate_dt) %>% tbl_dt(FALSE)
  128.   }
  129.   else
  130.     out_dt = merge(x, y, by = byNames, all = TRUE)
  131.   out_dt
  132. }
  133. t = proc.time()
  134. out2 = Reduce(function(df1, df2) aggregate_merge(df1, df2, byNames="SP"), list(x,y,z))
  135. proc.time() - t
  136. #   user  system elapsed
  137. #   0.31    0.00    0.32
  138.  
  139. t = proc.time()
  140. wide_table = rbind.fill(list(x, y, z)) %>% tbl_dt(FALSE)
  141. sum_without_na = function(x) ifelse(all(is.na(x)), NA_integer_, sum(x, na.rm = TRUE))
  142. out3 = wide_table %>% group_by(SP) %>% summarise_each(funs(sum_without_na))
  143. proc.time() - t
  144. #   user  system elapsed
  145. #   8.61    0.00    8.66
  146.  
  147. t = proc.time()
  148. out4 = list(x, y, z) %>% llply(function(x){
  149.   gather(x, variable, values, -SP) %>% mutate(variable = as.character(variable))
  150.   }) %>% bind_rows %>% tbl_dt(FALSE) %>% group_by(SP, variable) %>%
  151.   summarise(values = sum(values)) %>% spread(variable, values)
  152. proc.time() - t
  153. #   user  system elapsed
  154. #   1.06    0.12    1.19
  155.  
  156. if (run_plyr_method)
  157. {
  158.   t = proc.time()
  159.   wide_table = rbind.fill(list(x, y, z)) %>% tbl_dt(FALSE)
  160.   out5 = ddply(wide_table, .(SP), function(x) colSums(x, na.rm = TRUE)) %>% tbl_dt(FALSE)
  161.   proc.time() - t
  162. #   user  system elapsed
  163. #  49.65    0.05   50.05
  164. }
  165.  
  166. out1 %<>% select_(.dots = names(out4)) %>% arrange(SP) %>% mutate_each(funs(as.integer))
  167. out2 %<>% select_(.dots = names(out4)) %>% arrange(SP)
  168. out3 %<>% select_(.dots = names(out4)) %>% arrange(SP)
  169. out4 %<>% arrange(SP)
  170. if (run_plyr_method)
  171.   out5 %<>% arrange(SP)
  172.  
  173. all.equal(out4, out1) # TRUE
  174. all.equal(out4, out2) # TRUE
  175. all.equal(out4, out3) # TRUE
  176. if (run_plyr_method)
  177.   all.equal(out4, out5) # TRUE
Advertisement
Add Comment
Please, Sign In to add comment