SHOW:
|
|
- or go back to the newest paste.
| 1 | library(kernlab) | |
| 2 | N = 10000 | |
| 3 | p = 5 | |
| 4 | b = 1000 | |
| 5 | X = matrix(rnorm(N*p), ncol = p) | |
| 6 | center = X[sample(1:N, b),] | |
| 7 | sigma = 1 | |
| 8 | kernel_X <- kernelMatrix(rbfdot(sigma=1/(2*sigma^2)), X, center) | |
| 9 | ||
| 10 | library(Rcpp) | |
| 11 | library(RcppArmadillo) | |
| 12 | library(inline) | |
| 13 | settings <- getPlugin("Rcpp")
| |
| 14 | settings$env$PKG_CXXFLAGS <- paste('-fopenmp', settings$env$PKG_CXXFLAGS)
| |
| 15 | settings$env$PKG_LIBS <- paste('-fopenmp -lgomp', settings$env$PKG_LIBS)
| |
| 16 | do.call(Sys.setenv, settings$env) | |
| 17 | sourceCpp(code = ' | |
| 18 | // [[Rcpp::depends(RcppArmadillo)]] | |
| 19 | #include <RcppArmadillo.h> | |
| 20 | #include <omp.h> | |
| 21 | using namespace Rcpp; | |
| 22 | using namespace arma; | |
| 23 | // [[Rcpp::export]] | |
| 24 | ||
| 25 | NumericMatrix kernelMatrix_cpp(NumericMatrix Xr, NumericMatrix Centerr, double sigma) {
| |
| 26 | omp_set_num_threads(omp_get_max_threads()); | |
| 27 | uword n = Xr.nrow(), b = Centerr.nrow(), row_index, col_index; | |
| 28 | mat X(Xr.begin(), n, Xr.ncol(), false); | |
| 29 | mat Center(Centerr.begin(), b, Centerr.ncol(), false); | |
| 30 | mat KerX(n, b); | |
| 31 | #pragma omp parallel private(row_index, col_index) | |
| 32 | for (row_index = 0; row_index < n; row_index++) | |
| 33 | {
| |
| 34 | #pragma omp for nowait | |
| 35 | for (col_index = 0; col_index < b; col_index++) | |
| 36 | {
| |
| 37 | KerX(row_index, col_index) = exp(sum(square(X.row(row_index) - Center.row(col_index))) / (-2.0 * sigma * sigma)); | |
| 38 | } | |
| 39 | } | |
| 40 | return wrap(KerX); | |
| 41 | }') | |
| 42 | kernel_X_cpp = kernelMatrix_cpp(X, center, sigma) | |
| 43 | all.equal([email protected], kernel_X_cpp) | |
| 44 | sum(abs([email protected] - kernel_X_cpp) < 1e-12) | |
| 45 | ||
| 46 | library(rbenchmark) | |
| 47 | benchmark(kernelMatrix_cpp(X, center, sigma), kernelMatrix(rbfdot(sigma=1/(2*sigma^2)), X, center), columns=c("test", "replications","elapsed", "relative"), replications=50, order="relative") |