Advertisement
Guest User

Untitled

a guest
Jul 17th, 2019
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.28 KB | None | 0 0
  # Packages and printing options ----
# No rm(list = ls()) here: it wipes the caller's workspace yet still leaves
# loaded packages and options behind, so it does not give a clean session.
library(grf)
# plyr, ranger and openssl are not referenced by the code in this script;
# they are kept loaded as in the original.
library(plyr)
library(ranger)
library(openssl)
# 22 significant digits is R's maximum; it keeps the tiny (~1e-10 scale)
# differences reported below from being rounded away when printed.
options(digits = 22)

  # Simulation settings ----
# Four simulated designs (listed in `designs`) are each used to fit pairs of
# regression forests; results are collected into `results`.
set.seed(12345)
n <- 500          # observations per design
p <- 10           # number of covariates
m <- 20           # not referenced by the code below; kept from the original
num.trees <- 2000
designs <- c("mostly_linear", "nonlinear1", "nonlinear2", "piecewise_linear")
# One row per comparison metric (5), one column per design. Sizing the
# columns from `designs` keeps the matrix in sync if designs change.
results <- matrix(0, nrow = 5, ncol = length(designs))

  # Compare grf OOB predictions with predictions rebuilt from forest weights ----
# For each design: fit two forests with the same seed, one on Y and one on
# Y plus 1e-10-scale noise, then fill one column of `results` with:
#   row 1:   mean |w1 - w2| over weight entries that differ (0 if none do);
#   rows 2-3: OOB predictions vs. each forest's weights applied to its own Y;
#   rows 4-5: OOB predictions vs. each forest's weights applied to the other
#             forest's Y.
# Relies on n, p, num.trees, designs and results defined earlier.

# Simulate covariates X (n x p) and outcome Y for one named design.
# Random draws happen in the same order as the original inline code, so the
# set.seed() stream is consumed identically. An unknown design name raises an
# error instead of silently reusing the previous iteration's X and Y.
simulate_design <- function(design, n, p) {
  switch(
    design,
    mostly_linear = {
      X <- matrix(rnorm(n * p), n, p)
      # Columns 4:10 are hard-coded, so this design requires p >= 10.
      Y <- pmax(X[, 1], 0) + pmin(X[, 3], 0) + rowSums(X[, 4:10]) + rnorm(n)
      list(X = X, Y = Y)
    },
    nonlinear1 = {
      X <- matrix(runif(n * p), n, p)
      zeta1 <- 1 + 1 / (1 + exp(-20 * (X[, 1] - 1 / 3)))
      zeta2 <- 1 + 1 / (1 + exp(-20 * (X[, 2] - 1 / 3)))
      list(X = X, Y = zeta1 * zeta2 + rnorm(n))
    },
    nonlinear2 = {
      X <- matrix(runif(n * p), n, p)
      zeta1 <- 2 / (1 + exp(-12 * (X[, 1] - 1 / 2)))
      zeta2 <- 2 / (1 + exp(-12 * (X[, 2] - 1 / 2)))
      # The noise term was commented out in the original, so this design is
      # noiseless: Y is a deterministic function of X.
      list(X = X, Y = zeta1 * zeta2)
    },
    piecewise_linear = {
      X <- matrix(rnorm(n * p), n, p)
      b1 <- rnorm(p)
      b2 <- rnorm(p)
      # Linear model with coefficients b1 left of X1 = -0.4 and b2 from there on.
      mu <- as.numeric(X %*% b1) * (X[, 1] < -0.4) +
        as.numeric(X %*% b2) * (X[, 1] >= -0.4)
      list(X = X, Y = mu + rnorm(n))
    },
    stop("Unknown design: ", design, call. = FALSE)
  )
}

for (i in seq_along(designs)) {
  design <- designs[i]
  print(design)

  # Generate data ----
  sim <- simulate_design(design, n, p)
  X <- sim$X
  # Coerce to a plain numeric vector so Y1/Y2 have the same shape in every
  # design.
  Y <- as.numeric(sim$Y)

  # Fit two forests: original Y and a 1e-10-scale perturbation of it ----
  # Both use the same seed, so differences come from the perturbation of Y.
  Y1 <- Y
  Y2 <- Y + 1e-10 * rnorm(n)

  rf1 <- regression_forest(X, Y1, seed = 12345, num.trees = num.trees)
  rf2 <- regression_forest(X, Y2, seed = 12345, num.trees = num.trees)

  # predict() with no newdata returns out-of-bag predictions (per grf docs).
  yhat1_oob <- predict(rf1)$predictions
  yhat2_oob <- predict(rf2)$predictions

  # Forest weights with no newdata; w %*% Y yields one prediction per row,
  # which is compared element-wise with the OOB predictions below.
  w1 <- get_sample_weights(rf1)
  w2 <- get_sample_weights(rf2)
  w_diff <- as.matrix(w1 - w2)
  changed <- w_diff != 0
  # Averaging over an empty selection would give NaN; identical weights mean
  # a zero gap.
  w_gap <- if (any(changed)) mean(abs(w_diff[changed])) else 0

  yhat1_direct <- as.numeric(w1 %*% Y1)
  yhat2_direct <- as.numeric(w2 %*% Y2)

  # Swap outcomes: weights from one forest, outcome from the other.
  yhat1_cross <- as.numeric(w1 %*% Y2)
  yhat2_cross <- as.numeric(w2 %*% Y1)

  results[, i] <- c(
    w_gap,
    mean(abs(yhat1_oob - yhat1_direct)),
    mean(abs(yhat2_oob - yhat2_direct)),
    mean(abs(yhat1_oob - yhat1_cross)),
    mean(abs(yhat2_oob - yhat2_cross))
  )
}
  # Label the results matrix and report it ----
# Rows are the five comparison metrics filled in by the loop; columns are the
# simulated designs, in the same order as `designs`.
metric_labels <- c(
  "w1 vs w2",
  "yhat1_oob vs yhat1_direct",
  "yhat2_oob vs yhat2_direct",
  "yhat1_oob vs yhat1_cross",
  "yhat2_oob vs yhat2_cross"
)
dimnames(results) <- list(metric_labels, designs)

print(results)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement