Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- library(Rcpp)
- # Compile the C++ helpers below and expose CppKruskalWallis() to R.
- # NOTE: the C++ source sits inside a single-quoted R string, so it must not
- # contain apostrophes or backslashes (including inside comments).
- sourceCpp('
- // [[Rcpp::depends(RcppArmadillo)]]
- // [[Rcpp::plugins(cpp11)]]
- #define ARMA_USE_CXX11
- #include <RcppArmadillo.h>
- using namespace Rcpp;
- // Mid-ranks of v: 1-based ranks where tied values share the mean of the
- // ranks they span. The result is in the original order of v.
- // If tieSum is non-null it receives sum(t^3 - t) over every run of t equal
- // values, which is the numerator of the Kruskal-Wallis tie correction, so
- // callers do not need to sort v a second time.
- arma::vec avgRank(const arma::vec &v, double *tieSum = nullptr) {
-   const arma::uvec idx = arma::sort_index(v);
-   const std::size_t len = idx.n_elem;
-   arma::vec r(len);
-   double ties = 0.0;
-   for (std::size_t n, i = 0; i < len; i += n) {
-     // n = length of the run of equal values starting at sorted position i
-     n = 1;
-     while (i + n < len && v[idx[i]] == v[idx[i + n]]) ++n;
-     // sorted positions i..i+n-1 hold ranks i+1..i+n; their mean is i + (n + 1) / 2
-     const double midRank = i + (n + 1.0) / 2.0;
-     for (std::size_t k = 0; k < n; ++k)
-       r[idx[i + k]] = midRank;
-     ties += std::pow(n, 3.0) - n;
-   }
-   if (tieSum != nullptr) *tieSum = ties;
-   return r;
- }
- // Kruskal-Wallis rank sum test.
- //   x        : observations (must not contain NaN; the R wrapper filters them)
- //   groupVec : integer group code of each observation, same length as x
- // Returns a list with elements statistic, df and p.value.
- // [[Rcpp::export]]
- List CppKruskalWallis(const arma::vec &x, const arma::Col<arma::sword> groupVec) {
-   // guard direct callers: the loops below index ranks by positions of groupVec
-   if (x.n_elem != groupVec.n_elem)
-     Rcpp::stop("x and groupVec must have the same length");
-   // get rank of x, and the tie term from the same sorted pass
-   double tiesFactor = 0.0;
-   arma::vec ranks = avgRank(x, &tiesFactor);
-   // get sum of ranks (and size) of each group by walking runs of equal codes
-   arma::Col<arma::sword> uniGrps = arma::unique(groupVec);
-   arma::vec rankSum = arma::zeros<arma::vec>(uniGrps.n_elem);
-   arma::uvec cnt = arma::zeros<arma::uvec>(uniGrps.n_elem), idx = arma::sort_index(groupVec);
-   std::size_t j = 0;
-   for (std::size_t n, i = 0; i < idx.size(); i += n) {
-     n = 1;
-     while (i + n < idx.size() && groupVec[idx[i]] == groupVec[idx[i+n]]) ++n;
-     for (std::size_t k = 0; k < n; ++k)
-       rankSum[j] += ranks[idx[i+k]];
-     cnt[j] = n;
-     ++j;
-   }
-   double stat1 = sum(square(rankSum) / cnt);
-   // calculate the statistic of K-W test:
-   // (12 / (N (N + 1)) * sum(R_j^2 / n_j) - 3 (N + 1)) / (1 - ties / (N^3 - N))
-   double l = (double) x.n_elem, lp1 = l + 1.0,
-     numerator = 12.0 * stat1 / (l * lp1) - 3.0 * lp1,
-     denominator = 1.0 - tiesFactor / (pow(l, 3.0) - l);
-   NumericVector statistic = {numerator / denominator};
-   // find the p-value: upper tail of the chi-squared distribution, not on log scale
-   double df = (double) uniGrps.n_elem - 1.0;
-   NumericVector pValue = pchisq(statistic, df, false, false);
-   // return result
-   return List::create(_["statistic"] = statistic, _["df"] = df, _["p.value"] = pValue);
- }')
- library(fastmatch)
- # Build a factor from `x` using fastmatch::fmatch() for the level lookup.
- #   x       : vector to convert; returned untouched if it is already a factor
- #   levels  : level values; defaults to the sorted unique values of `x`
- #   labels  : level labels; defaults to `levels`. The default is evaluated
- #             lazily, so it picks up the levels computed inside the function.
- #   na.last : passed to sort(); when TRUE, unmatched values get the last level
- fast_factor <- function(x, levels=NULL, labels=levels, na.last=NA) {
-   if (is.factor(x)) {
-     return(x)
-   }
-   if (is.null(levels)) {
-     levels <- sort(unique.default(x), na.last = na.last)
-   }
-   # code assigned to values that are not found among the levels
-   no_match <- if (isTRUE(na.last)) length(levels) else NA_integer_
-   suppressWarnings(codes <- fmatch(x, levels, nomatch = no_match))
-   # `labels` is first touched here, after `levels` has been filled in
-   levels(codes) <- as.character(labels)
-   class(codes) <- "factor"
-   codes
- }
- # Kruskal-Wallis test of `x` grouped by `g`, computed by CppKruskalWallis().
- #   x : observations
- #   g : group of each observation (integer, factor, or anything fmatch() accepts)
- # Pairs with a missing value in either `x` or `g` are dropped before the test.
- KruskalWallis <- function(x, g) {
-   stopifnot(length(x) == length(g))
-   keep <- complete.cases(x, g)
-   # Turn the grouping into integer codes. is.integer() is FALSE for factors,
-   # so factors are handled first and plain integer vectors pass through as is.
-   if (is.factor(g)) {
-     g <- as.integer(g)
-   } else if (!is.integer(g)) {
-     g <- fmatch(g, sort(unique(g)))
-   }
-   CppKruskalWallis(x[keep], g[keep])
- }
- # Benchmark 1: integer group codes. 2e4 observations are resampled from
- # 1.6e4 normal draws (so ties are guaranteed), with groups drawn from 1:1000.
- x <- sample(rnorm(1.6e4), size = 2e4, replace = TRUE)
- g <- sample(1000, size = 2e4, replace = TRUE)
- microbenchmark::microbenchmark(
-   KruskalWallis(x, g),
-   kruskal.test(x, g),
-   times = 100L
- )
- # Unit: milliseconds
- # expr min lq mean median uq max neval
- # KruskalWallis(x, g) 4.842012 4.880713 5.039397 4.932967 4.978144 7.673973 100
- # kruskal.test(x, g) 64.904103 65.876902 66.632909 66.348846 67.016101 71.843507 100
- # Benchmark 2: character group labels; kruskal.test() is given a factor
- # prebuilt with fast_factor().
- x <- sample(rnorm(1.6e4), size = 2e4, replace = TRUE)
- g <- sample(paste0("A", 1:1000), size = 2e4, replace = TRUE)
- microbenchmark::microbenchmark(
-   KruskalWallis(x, g),
-   kruskal.test(x, fast_factor(g)),
-   times = 100L
- )
- # Unit: milliseconds
- # expr min lq mean median uq max neval
- # KruskalWallis(x, g) 9.010593 9.119167 9.345936 9.226688 9.27081 15.42053 100
- # kruskal.test(x, fast_factor(g)) 64.067736 65.450286 67.139090 66.286351 68.08648 92.51595 100
Advertisement
Add Comment
Please, Sign In to add comment