celestialgod

Fast Kruskal-Wallis Test

Oct 14th, 2017
202
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
R 3.60 KB | None | 0 0
  1. library(Rcpp)
  # Compile the C++ backend and export CppKruskalWallis() into the session.
  # The source text is passed through the `code` argument: the first
  # positional argument of sourceCpp() is `file`, so an unnamed string would
  # be interpreted as the path of a .cpp file.
  # NOTE: the C++ text lives inside a single-quoted R string, so it must not
  # contain any single-quote characters (including in comments).
  sourceCpp(code = '
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::plugins(cpp11)]]
#define ARMA_USE_CXX11
#include <RcppArmadillo.h>
using namespace Rcpp;

// Average (mid) ranks of v, 1-based: every run of equal values receives the
// mean of the rank positions the run occupies in sorted order.
// If tieSum is not null it receives sum(t^3 - t) over all runs of equal
// values, where t is the run length (runs of length 1 contribute 0).
arma::vec avgRank(const arma::vec &v, double *tieSum = nullptr) {
  const arma::uvec order = arma::sort_index(v);
  const std::size_t total = order.size();

  arma::vec ranks(v.n_elem);
  double ties = 0.0;
  std::size_t start = 0;
  while (start < total) {
    // extend [start, stop) over the run of values equal to v[order[start]]
    std::size_t stop = start + 1;
    while (stop < total && v[order[start]] == v[order[stop]]) ++stop;
    const std::size_t n = stop - start;

    // 1-based positions start+1 .. start+n average to start + (n + 1) / 2
    const double shared = start + (n + 1.0) / 2.0;
    for (std::size_t pos = start; pos < stop; ++pos)
      ranks[order[pos]] = shared;

    ties += std::pow(n, 3.0) - n;
    start = stop;
  }
  if (tieSum != nullptr) *tieSum = ties;
  return ranks;
}

// Kruskal-Wallis rank sum test.
//   x        : observations
//   groupVec : integer group code of each observation (same length as x)
// Returns a list with elements statistic, df and p.value, where p.value is
// the upper tail of a chi-squared distribution with df degrees of freedom.
// Missing values are not handled here; the caller has to remove them.
// [[Rcpp::export]]
List CppKruskalWallis(const arma::vec &x, const arma::Col<arma::sword> &groupVec) {
  // the loops below index ranks through positions of groupVec using
  // unchecked operator[], so the two lengths have to agree
  if (x.n_elem != groupVec.n_elem)
    stop("x and groupVec must have the same length");

  // ranks of x; the tie correction term is collected in the same sorted pass
  double tiesFactor = 0.0;
  const arma::vec ranks = avgRank(x, &tiesFactor);

  // sum of ranks and number of observations per group: walk the observations
  // in group order, one run of equal group codes at a time
  const arma::Col<arma::sword> uniGrps = arma::unique(groupVec);
  arma::vec rankSum = arma::zeros<arma::vec>(uniGrps.n_elem);
  arma::vec cnt = arma::zeros<arma::vec>(uniGrps.n_elem);
  const arma::uvec byGroup = arma::sort_index(groupVec);
  const std::size_t total = byGroup.size();
  std::size_t j = 0;
  for (std::size_t start = 0; start < total; ++j) {
    std::size_t stop = start + 1;
    while (stop < total && groupVec[byGroup[start]] == groupVec[byGroup[stop]]) ++stop;
    for (std::size_t pos = start; pos < stop; ++pos)
      rankSum[j] += ranks[byGroup[pos]];
    cnt[j] = static_cast<double>(stop - start);
    start = stop;
  }
  const double stat1 = arma::accu(arma::square(rankSum) / cnt);

  // statistic of the K-W test, divided by the correction factor for ties
  const double l = static_cast<double>(x.n_elem), lp1 = l + 1.0;
  const double numerator = 12.0 * stat1 / (l * lp1) - 3.0 * lp1;
  const double denominator = 1.0 - tiesFactor / (std::pow(l, 3.0) - l);
  NumericVector statistic = NumericVector::create(numerator / denominator);

  // p-value: upper tail (lower = false), not on the log scale (log = false)
  const double df = static_cast<double>(uniGrps.n_elem) - 1.0;
  NumericVector pValue = pchisq(statistic, df, false, false);

  return List::create(_["statistic"] = statistic, _["df"] = df, _["p.value"] = pValue);
}')
  63.  
  64.  
  65. library(fastmatch)
  # Build a factor from `x`, using fastmatch::fmatch() to compute the codes.
  #
  # x:       vector to encode; returned unchanged when it already is a factor.
  # levels:  level values; when NULL they are the sorted unique values of `x`.
  # labels:  level labels; the default `levels` is evaluated lazily, i.e. after
  #          `levels` has been filled in below.
  # na.last: forwarded to sort(); when it is exactly TRUE, values that do not
  #          match any level are coded as the last level instead of NA.
  fast_factor <- function(x, levels = NULL, labels = levels, na.last = NA) {
    if (is.factor(x)) {
      return(x)
    }
    if (is.null(levels)) {
      levels <- sort(unique.default(x), na.last = na.last)
    }
    fallback <- if (isTRUE(na.last)) length(levels) else NA_integer_
    codes <- suppressWarnings(fmatch(x, levels, nomatch = fallback))
    # turn the integer codes into a factor by attaching levels and class
    levels(codes) <- as.character(labels)
    class(codes) <- "factor"
    codes
  }
  74.  
  # Kruskal-Wallis rank sum test computed by the compiled CppKruskalWallis().
  #
  # x: vector of observations.
  # g: grouping vector of the same length as `x`: integer codes, a factor, or
  #    any other vector whose values can be sorted and matched.
  # Observations with a missing value in `x` or `g` are dropped first.
  # Returns whatever CppKruskalWallis() returns for the remaining data.
  KruskalWallis <- function(x, g) {
    stopifnot(length(x) == length(g))
    keep <- complete.cases(x, g)
    x <- x[keep]
    g <- g[keep]
    # is.integer() is TRUE for factors, so the factor test has to come first;
    # as.integer() strips the factor attributes and leaves the level codes.
    if (is.factor(g)) {
      g <- as.integer(g)
    } else if (!is.integer(g)) {
      # recode arbitrary values as their position among the sorted unique values
      g <- fmatch(g, sort(unique(g)))
    }
    CppKruskalWallis(x, g)
  }
  87.  
  # Benchmark 1: integer group codes.
  # 20,000 observations resampled with replacement from 16,000 normal draws
  # (so ties are guaranteed), assigned at random to group codes 1..1000.
  x <- sample(rnorm(1.6e4), size = 2e4, replace = TRUE)
  g <- sample.int(1000L, size = 2e4, replace = TRUE)
  microbenchmark::microbenchmark(
    KruskalWallis(x, g),
    kruskal.test(x, g),
    times = 100L
  )
  # Output recorded in the original post:
  # Unit: milliseconds
  #                expr       min        lq      mean    median        uq       max neval
  #  KruskalWallis(x, g)  4.842012  4.880713  5.039397  4.932967  4.978144  7.673973   100
  #   kruskal.test(x, g) 64.904103 65.876902 66.632909 66.348846 67.016101 71.843507   100
  99.  
  # Benchmark 2: character group labels ("A1" .. "A1000").
  # KruskalWallis() recodes the labels itself; kruskal.test() is given the
  # labels already converted with fast_factor().
  x <- sample(rnorm(1.6e4), size = 2e4, replace = TRUE)
  g <- sample(paste0("A", seq_len(1000L)), size = 2e4, replace = TRUE)
  microbenchmark::microbenchmark(
    KruskalWallis(x, g),
    kruskal.test(x, fast_factor(g)),
    times = 100L
  )
  # Output recorded in the original post:
  # Unit: milliseconds
  #                             expr       min        lq      mean    median       uq      max neval
  #              KruskalWallis(x, g)  9.010593  9.119167  9.345936  9.226688  9.27081 15.42053   100
  #  kruskal.test(x, fast_factor(g)) 64.067736 65.450286 67.139090 66.286351 68.08648 92.51595   100
Advertisement
Add Comment
Please, Sign In to add comment