SHOW:
|
|
- or go back to the newest paste.
1 | library(kernlab) | |
2 | library(Rcpp) | |
3 | - | p = 5 |
3 | + | |
4 | - | b = 1000 |
4 | + | |
5 | p = 1000 | |
6 | b = 3000 | |
7 | - | sigma = 1 |
7 | + | |
8 | center = X[sample(1:N, b),] | |
9 | sigma = 3 | |
10 | ## For windows user | |
11 | # library(inline) | |
12 | - | library(inline) |
12 | + | # settings <- getPlugin("Rcpp") |
13 | - | settings <- getPlugin("Rcpp") |
13 | + | # settings$env$PKG_CXXFLAGS <- paste('-fopenmp', settings$env$PKG_CXXFLAGS) |
14 | - | settings$env$PKG_CXXFLAGS <- paste('-fopenmp', settings$env$PKG_CXXFLAGS) |
14 | + | # settings$env$PKG_LIBS <- paste('-fopenmp -lgomp', settings$env$PKG_LIBS) |
15 | - | settings$env$PKG_LIBS <- paste('-fopenmp -lgomp', settings$env$PKG_LIBS) |
15 | + | # do.call(Sys.setenv, settings$env) |
16 | - | do.call(Sys.setenv, settings$env) |
16 | + | |
17 | // [[Rcpp::depends(RcppArmadillo)]] | |
18 | #include <RcppArmadillo.h> | |
19 | using namespace Rcpp; | |
20 | - | #include <omp.h> |
20 | + | |
21 | ||
22 | // [[Rcpp::export]] | |
23 | NumericMatrix kernelMatrix_cpp(NumericMatrix Xr, NumericMatrix Centerr, double sigma) { | |
24 | uword n = Xr.nrow(), b = Centerr.nrow(), row_index, col_index; | |
25 | mat X(Xr.begin(), n, Xr.ncol(), false), Center(Centerr.begin(), b, Centerr.ncol(), false), KerX(X*Center.t()); | |
26 | - | omp_set_num_threads(omp_get_max_threads()); |
26 | + | colvec X_sq = sum(square(X), 1) / 2; |
27 | - | uword n = Xr.nrow(), b = Centerr.nrow(), row_index, col_index; |
27 | + | rowvec Center_sq = (sum(square(Center), 1)).t() / 2; |
28 | - | mat X(Xr.begin(), n, Xr.ncol(), false); |
28 | + | KerX.each_row() -= Center_sq; |
29 | - | mat Center(Centerr.begin(), b, Centerr.ncol(), false); |
29 | + | KerX.each_col() -= X_sq; |
30 | - | mat KerX(n, b); |
30 | + | KerX *= 1 / (sigma * sigma); |
31 | - | #pragma omp parallel private(row_index, col_index) |
31 | + | KerX = exp(KerX); |
32 | - | for (row_index = 0; row_index < n; row_index++) |
32 | + | return wrap(KerX); |
33 | - | { |
33 | + | |
34 | - | #pragma omp for nowait |
34 | + | |
35 | - | for (col_index = 0; col_index < b; col_index++) |
35 | + | t1 = Sys.time() |
36 | - | { |
36 | + | kernel_X_cpp = kernelMatrix_cpp2(X, center, sigma) |
37 | - | KerX(row_index, col_index) = exp(sum(square(X.row(row_index) - Center.row(col_index))) / (-2.0 * sigma * sigma)); |
37 | + | Sys.time() - t1 |
38 | - | } |
38 | + | t1 = Sys.time() |
39 | - | } |
39 | + | |
40 | - | return wrap(KerX); |
40 | + | Sys.time() - t1 |
41 | all.equal([email protected], kernel_X_cpp) | |
42 | - | kernel_X_cpp = kernelMatrix_cpp(X, center, sigma) |
42 | + | |
43 | library(rbenchmark) | |
44 | - | sum(abs([email protected] - kernel_X_cpp) < 1e-12) |
44 | + | benchmark(cpp = kernelMatrix_cpp(X, center, sigma), kernlab = kernelMatrix(rbfdot(sigma=1/(2*sigma^2)), X, center), columns=c("test", "replications","elapsed", "relative"), replications=10, order="relative") |