View difference between Paste ID: X4jnRn4A and AHEwSTZ5
SHOW: | | - or go back to the newest paste.
1
library(kernlab)
2
N = 10000
3
p = 5
4
b = 1000
5
X = matrix(rnorm(N*p), ncol = p)
6
center = X[sample(1:N, b),]
7
sigma = 1
8
kernel_X <- kernelMatrix(rbfdot(sigma=1/(2*sigma^2)), X, center)
9
10
library(Rcpp)
11
library(RcppArmadillo)
12
library(inline)
13
settings <- getPlugin("Rcpp")
14
settings$env$PKG_CXXFLAGS <- paste('-fopenmp', settings$env$PKG_CXXFLAGS)
15
settings$env$PKG_LIBS <- paste('-fopenmp -lgomp', settings$env$PKG_LIBS)
16
do.call(Sys.setenv, settings$env)
17
sourceCpp(code = '
18
// [[Rcpp::depends(RcppArmadillo)]]
19
#include <RcppArmadillo.h>
20
#include <omp.h>
21
using namespace Rcpp;
22
using namespace arma;
23
// [[Rcpp::export]]
24
25
NumericMatrix kernelMatrix_cpp(NumericMatrix Xr, NumericMatrix Centerr, double sigma) {
26
	omp_set_num_threads(omp_get_max_threads());
27
	uword n = Xr.nrow(), b = Centerr.nrow(), row_index, col_index;
28
	mat X(Xr.begin(), n, Xr.ncol(), false);
29
	mat Center(Centerr.begin(), b, Centerr.ncol(), false);
30
	mat KerX(n, b);
31
	#pragma omp parallel private(row_index, col_index)
32
	for (row_index = 0; row_index < n; row_index++)
33
	{
34
		#pragma omp for nowait
35
		for (col_index = 0; col_index < b; col_index++)
36
		{
37
			KerX(row_index, col_index) = exp(sum(square(X.row(row_index) - Center.row(col_index))) / (-2.0 * sigma * sigma));
38
		}
39
	}
40
    return wrap(KerX);
41
}')
42
kernel_X_cpp = kernelMatrix_cpp(X, center, sigma)
43
all.equal([email protected], kernel_X_cpp)
44
sum(abs([email protected] - kernel_X_cpp) < 1e-12)
45
46
library(rbenchmark)
47
benchmark(kernelMatrix_cpp(X, center, sigma), kernelMatrix(rbfdot(sigma=1/(2*sigma^2)), X, center), columns=c("test", "replications","elapsed", "relative"), replications=50, order="relative")