Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Experiment: how much does a grf regression forest change when the response
# is perturbed by noise of order 1e-10?
#
# For each simulated design, two forests are fit with the same seed: one on Y
# and one on Y plus a tiny perturbation. The script then compares
#   * the forest weights of the two fits, and
#   * each forest's out-of-bag predictions against predictions rebuilt by
#     multiplying the weight matrix with either response vector.
# The mean absolute discrepancies are collected in `results` (one column per
# design) and printed at the end.
#
# NOTE(review): the original script started with `rm(list = ls())`. It was
# dropped: nothing below relies on an empty workspace, and wiping the caller's
# session is a destructive side effect.

library(grf)
library(plyr)
library(ranger)
library(openssl)

# Print with as many significant digits as R allows, so that very small
# discrepancies remain visible in the final table.
options(digits = 22)

# ---- Settings ----
set.seed(12345)
n <- 500          # observations per design
p <- 10           # covariates per design
num.trees <- 2000 # trees per forest

designs <- c("mostly_linear", "nonlinear1", "nonlinear2", "piecewise_linear")
metrics <- c(
  "w1 vs w2",
  "yhat1_oob vs yhat1_direct",
  "yhat2_oob vs yhat2_direct",
  "yhat1_oob vs yhat1_cross",
  "yhat2_oob vs yhat2_cross"
)

# ---- Helpers ----

# Simulate one data set for the named design.
#
# The order of the random draws inside each branch is deliberate: it matches
# the original script so that the same seed yields the same data.
#
# Args:
#   design: one of the entries of `designs`.
#   n: number of observations.
#   p: number of covariates.
# Returns:
#   list(X = n x p covariate matrix, Y = response of length n).
simulate_design <- function(design, n, p) {
  switch(
    design,
    mostly_linear = {
      # Columns 4:10 are hard-coded in the signal, so at least 10 are needed.
      stopifnot(p >= 10)
      X <- matrix(rnorm(n * p), n, p)
      # apply(..., 1, sum) is kept (instead of rowSums) so that the simulated
      # Y cannot differ in the last bits from the original script's output.
      Y <- pmax(X[, 1], 0) + pmin(X[, 3], 0) +
        apply(X[, 4:10], 1, sum) + rnorm(n)
    },
    nonlinear1 = {
      X <- matrix(runif(n * p), n, p)
      zeta1 <- 1 + 1 / (1 + exp(-20 * (X[, 1] - (1 / 3))))
      zeta2 <- 1 + 1 / (1 + exp(-20 * (X[, 2] - (1 / 3))))
      Y <- zeta1 * zeta2 + rnorm(n)
    },
    nonlinear2 = {
      X <- matrix(runif(n * p), n, p)
      zeta1 <- 2 / (1 + exp(-12 * (X[, 1] - (1 / 2))))
      zeta2 <- 2 / (1 + exp(-12 * (X[, 2] - (1 / 2))))
      # This design is noise-free (the noise term was commented out in the
      # original script).
      Y <- zeta1 * zeta2
    },
    piecewise_linear = {
      X <- matrix(rnorm(n * p), n, p)
      b1 <- matrix(rnorm(p), ncol = 1)
      b2 <- matrix(rnorm(p), ncol = 1)
      # Two linear regimes, switched on the first covariate at -0.4.
      mu <- X %*% b1 * (X[, 1] < (-0.4)) + X %*% b2 * (X[, 1] >= -0.4)
      Y <- mu + rnorm(n)
    },
    stop("Unknown design: ", design, call. = FALSE)
  )
  list(X = X, Y = Y)
}

# Mean absolute value of the non-zero entries of `d`.
# Returns 0 when `d` has no non-zero entry (i.e. the two weight matrices are
# identical); the original expression produced NaN in that case.
mean_abs_nonzero <- function(d) {
  nonzero <- d[d != 0]
  if (length(nonzero) == 0L) {
    return(0)
  }
  mean(abs(nonzero))
}

# Accessor for the forest weight matrix.
# NOTE(review): newer grf releases appear to expose this as
# `get_forest_weights`, while the original script called
# `get_sample_weights`; confirm against the installed grf version. The new
# name is used when available, otherwise the original one.
forest_weights <- if (exists("get_forest_weights", mode = "function")) {
  get_forest_weights
} else {
  get_sample_weights
}

# ---- Run the experiment ----

# One row per metric, one column per design; sized from the labels so the
# table cannot fall out of sync with them.
results <- matrix(
  0,
  nrow = length(metrics),
  ncol = length(designs),
  dimnames = list(metrics, designs)
)

for (i in seq_along(designs)) {
  design <- designs[i]
  print(design)

  # ------ GENERATE ORIGINAL DATA --------
  sim <- simulate_design(design, n, p)
  X <- sim$X
  Y <- sim$Y

  # ------ FIT FORESTS ON ORIGINAL AND PERTURBED RESPONSE --------
  Y1 <- Y
  Y2 <- Y + 1e-10 * rnorm(n)

  # Same seed and settings for both fits: only the response differs.
  rf1 <- regression_forest(X, Y1, seed = 12345, num.trees = num.trees)
  rf2 <- regression_forest(X, Y2, seed = 12345, num.trees = num.trees)

  # Called without new data, as in the original script.
  yhat1_oob <- predict(rf1)$predictions
  yhat2_oob <- predict(rf2)$predictions

  w1 <- forest_weights(rf1)
  w2 <- forest_weights(rf2)
  w_diff <- as.matrix(w1 - w2)

  # "direct": weights applied to the response the forest was trained on.
  yhat1_direct <- as.numeric(w1 %*% Y1)
  yhat2_direct <- as.numeric(w2 %*% Y2)
  # "cross": weights applied to the other response.
  yhat1_cross <- as.numeric(w1 %*% Y2)
  yhat2_cross <- as.numeric(w2 %*% Y1)

  results[, i] <- c(
    mean_abs_nonzero(w_diff),
    mean(abs(yhat1_oob - yhat1_direct)),
    mean(abs(yhat2_oob - yhat2_direct)),
    mean(abs(yhat1_oob - yhat1_cross)),
    mean(abs(yhat2_oob - yhat2_cross))
  )
}

print(results)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement