View difference between Paste ID: u5jqtc4h and X4jnRn4A
SHOW: | | - or go back to the newest paste.
1
library(kernlab)
2
library(Rcpp)
3-
p = 5
3+
4-
b = 1000
4+
5
p = 1000
6
b = 3000
7-
sigma = 1
7+
8
center = X[sample(1:N, b),]
9
sigma = 3
10
## For windows user
11
# library(inline)
12-
library(inline)
12+
# settings <- getPlugin("Rcpp")
13-
settings <- getPlugin("Rcpp")
13+
# settings$env$PKG_CXXFLAGS <- paste('-fopenmp', settings$env$PKG_CXXFLAGS)
14-
settings$env$PKG_CXXFLAGS <- paste('-fopenmp', settings$env$PKG_CXXFLAGS)
14+
# settings$env$PKG_LIBS <- paste('-fopenmp -lgomp', settings$env$PKG_LIBS)
15-
settings$env$PKG_LIBS <- paste('-fopenmp -lgomp', settings$env$PKG_LIBS)
15+
# do.call(Sys.setenv, settings$env)
16-
do.call(Sys.setenv, settings$env)
16+
17
// [[Rcpp::depends(RcppArmadillo)]]
18
#include <RcppArmadillo.h>
19
using namespace Rcpp;
20-
#include <omp.h>
20+
21
22
// [[Rcpp::export]]
23
NumericMatrix kernelMatrix_cpp(NumericMatrix Xr, NumericMatrix Centerr, double sigma) {
24
  uword n = Xr.nrow(), b = Centerr.nrow(), row_index, col_index;
25
  mat X(Xr.begin(), n, Xr.ncol(), false), Center(Centerr.begin(), b, Centerr.ncol(), false), KerX(X*Center.t());
26-
	omp_set_num_threads(omp_get_max_threads());
26+
  colvec X_sq = sum(square(X), 1) / 2;
27-
	uword n = Xr.nrow(), b = Centerr.nrow(), row_index, col_index;
27+
  rowvec Center_sq = (sum(square(Center), 1)).t() / 2;
28-
	mat X(Xr.begin(), n, Xr.ncol(), false);
28+
  KerX.each_row() -= Center_sq;
29-
	mat Center(Centerr.begin(), b, Centerr.ncol(), false);
29+
  KerX.each_col() -= X_sq;
30-
	mat KerX(n, b);
30+
  KerX *= 1 / (sigma * sigma);
31-
	#pragma omp parallel private(row_index, col_index)
31+
  KerX = exp(KerX);
32-
	for (row_index = 0; row_index < n; row_index++)
32+
  return wrap(KerX);
33-
	{
33+
34-
		#pragma omp for nowait
34+
35-
		for (col_index = 0; col_index < b; col_index++)
35+
t1 = Sys.time()
36-
		{
36+
kernel_X_cpp = kernelMatrix_cpp2(X, center, sigma)
37-
			KerX(row_index, col_index) = exp(sum(square(X.row(row_index) - Center.row(col_index))) / (-2.0 * sigma * sigma));
37+
 Sys.time() - t1
38-
		}
38+
 t1 = Sys.time()
39-
	}
39+
40-
    return wrap(KerX);
40+
Sys.time() - t1
41
all.equal([email protected], kernel_X_cpp)
42-
kernel_X_cpp = kernelMatrix_cpp(X, center, sigma)
42+
43
library(rbenchmark)
44-
sum(abs([email protected] - kernel_X_cpp) < 1e-12)
44+
benchmark(cpp = kernelMatrix_cpp(X, center, sigma), kernlab = kernelMatrix(rbfdot(sigma=1/(2*sigma^2)), X, center), columns=c("test", "replications","elapsed", "relative"), replications=10, order="relative")