SHOW:
|
|
- or go back to the newest paste.
| 1 | library(kernlab) | |
| 2 | library(Rcpp) | |
| 3 | - | p = 5 |
| 3 | + | |
| 4 | - | b = 1000 |
| 4 | + | |
| 5 | p = 1000 | |
| 6 | b = 3000 | |
| 7 | - | sigma = 1 |
| 7 | + | |
| 8 | center = X[sample(1:N, b),] | |
| 9 | sigma = 3 | |
| 10 | ## For windows user | |
| 11 | # library(inline) | |
| 12 | - | library(inline) |
| 12 | + | # settings <- getPlugin("Rcpp")
|
| 13 | - | settings <- getPlugin("Rcpp")
|
| 13 | + | # settings$env$PKG_CXXFLAGS <- paste('-fopenmp', settings$env$PKG_CXXFLAGS)
|
| 14 | - | settings$env$PKG_CXXFLAGS <- paste('-fopenmp', settings$env$PKG_CXXFLAGS)
|
| 14 | + | # settings$env$PKG_LIBS <- paste('-fopenmp -lgomp', settings$env$PKG_LIBS)
|
| 15 | - | settings$env$PKG_LIBS <- paste('-fopenmp -lgomp', settings$env$PKG_LIBS)
|
| 15 | + | # do.call(Sys.setenv, settings$env) |
| 16 | - | do.call(Sys.setenv, settings$env) |
| 16 | + | |
| 17 | // [[Rcpp::depends(RcppArmadillo)]] | |
| 18 | #include <RcppArmadillo.h> | |
| 19 | using namespace Rcpp; | |
| 20 | - | #include <omp.h> |
| 20 | + | |
| 21 | ||
| 22 | // [[Rcpp::export]] | |
| 23 | NumericMatrix kernelMatrix_cpp(NumericMatrix Xr, NumericMatrix Centerr, double sigma) {
| |
| 24 | uword n = Xr.nrow(), b = Centerr.nrow(), row_index, col_index; | |
| 25 | mat X(Xr.begin(), n, Xr.ncol(), false), Center(Centerr.begin(), b, Centerr.ncol(), false), KerX(X*Center.t()); | |
| 26 | - | omp_set_num_threads(omp_get_max_threads()); |
| 26 | + | colvec X_sq = sum(square(X), 1) / 2; |
| 27 | - | uword n = Xr.nrow(), b = Centerr.nrow(), row_index, col_index; |
| 27 | + | rowvec Center_sq = (sum(square(Center), 1)).t() / 2; |
| 28 | - | mat X(Xr.begin(), n, Xr.ncol(), false); |
| 28 | + | KerX.each_row() -= Center_sq; |
| 29 | - | mat Center(Centerr.begin(), b, Centerr.ncol(), false); |
| 29 | + | KerX.each_col() -= X_sq; |
| 30 | - | mat KerX(n, b); |
| 30 | + | KerX *= 1 / (sigma * sigma); |
| 31 | - | #pragma omp parallel private(row_index, col_index) |
| 31 | + | KerX = exp(KerX); |
| 32 | - | for (row_index = 0; row_index < n; row_index++) |
| 32 | + | return wrap(KerX); |
| 33 | - | {
|
| 33 | + | |
| 34 | - | #pragma omp for nowait |
| 34 | + | |
| 35 | - | for (col_index = 0; col_index < b; col_index++) |
| 35 | + | t1 = Sys.time() |
| 36 | - | {
|
| 36 | + | kernel_X_cpp = kernelMatrix_cpp2(X, center, sigma) |
| 37 | - | KerX(row_index, col_index) = exp(sum(square(X.row(row_index) - Center.row(col_index))) / (-2.0 * sigma * sigma)); |
| 37 | + | Sys.time() - t1 |
| 38 | - | } |
| 38 | + | t1 = Sys.time() |
| 39 | - | } |
| 39 | + | |
| 40 | - | return wrap(KerX); |
| 40 | + | Sys.time() - t1 |
| 41 | all.equal([email protected], kernel_X_cpp) | |
| 42 | - | kernel_X_cpp = kernelMatrix_cpp(X, center, sigma) |
| 42 | + | |
| 43 | library(rbenchmark) | |
| 44 | - | sum(abs([email protected] - kernel_X_cpp) < 1e-12) |
| 44 | + | benchmark(cpp = kernelMatrix_cpp(X, center, sigma), kernlab = kernelMatrix(rbfdot(sigma=1/(2*sigma^2)), X, center), columns=c("test", "replications","elapsed", "relative"), replications=10, order="relative") |