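# Optional: a fixed seed makes every random draw below repeatable; the value
# 2018 is an arbitrary choice, not something the script depends on.
set.seed(2018)
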
require(xgboost)
require(viopoints)

### Generate fake data ###
# Generate features
n = 1e4
dat = data.frame(weight = round(rnorm(n, 170, 20), 2),
                 state = sample(state.abb, n, replace = TRUE),
                 age = round(rnorm(n, 50, 10)),
                 income = round(rnorm(n, 70, 20), 3),
                 stringsAsFactors = FALSE)

# Generate outcome associated with features
a = (dat$weight < 170 &
     dat$state %in% state.abb[1:20] &
     dat$age > 50 &
     dat$income < 70)

b = (dat$weight < 180 &
     dat$state %in% state.abb[30:45] &
     dat$age < 30 &
     dat$income > 100)

c = (dat$weight > 200 &
     dat$state %in% state.abb[1:50] &  # all 50 states, so this clause is always TRUE
     dat$age > 70 &
     dat$income > 50)

d = (dat$weight > 150 & dat$weight < 190 &
     dat$state %in% state.abb[40:50] &
     dat$age > 40 & dat$age < 70 &
     dat$income > 90)

idx = which(a | b | c | d)

# Include some noise: a 5% base rate everywhere, then set 95% of the
# rule-matching rows to the positive class
outcome = sample(0:1, n, prob = c(.95, .05), replace = TRUE)
outcome[sample(idx, floor(.95 * length(idx)))] = 1
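
# Sanity check (optional): the rules plus noise should leave the positive
# class rare but present; exact counts vary from run to run.
table(outcome)
mean(outcome)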


### Fit model ###
# Get train/validation indices
idx_train = sample(1:nrow(dat), .7 * nrow(dat))
idx_val = setdiff(1:nrow(dat), idx_train)

# Format data for xgboost
xgb_dat = dat
xgb_dat$state = as.numeric(as.factor(xgb_dat$state))
xgb_dat = as.matrix(xgb_dat)

dtrain = xgb.DMatrix(data = xgb_dat[idx_train, ], label = outcome[idx_train])
dval = xgb.DMatrix(data = xgb_dat[idx_val, ], label = outcome[idx_val])
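
# Sketch of an alternative encoding: state is an unordered categorical, and the
# integer encoding above imposes an arbitrary ordering on it. A one-hot design
# matrix via Matrix::sparse.model.matrix is the more common choice; the names
# xgb_sparse and dtrain_sparse are illustrative only and are not used below.
require(Matrix)
xgb_sparse = sparse.model.matrix(~ . - 1, data = transform(dat, state = factor(state)))
dtrain_sparse = xgb.DMatrix(data = xgb_sparse[idx_train, ], label = outcome[idx_train])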

# Set hyperparameters
xgb_params = list(max_depth = 3,
                  eta = 0.1,
                  min_child_weight = 1,
                  subsample = 0.5,
                  colsample_bytree = 0.5,
                  colsample_bylevel = 0.5,
                  lambda = 0.1,
                  alpha = 0,
                  gamma = 0,
                  max_delta_step = 0,
                  # upweight the rare positive class (roughly inverse prevalence)
                  scale_pos_weight = 1/mean(outcome[idx_train]),
                  num_parallel_tree = 1)

# Train model
bst = xgb.train(data = dtrain,
                params = xgb_params,
                nrounds = 500,
                watchlist = list(train = dtrain, validate = dval),
                base_score = 0.5,
                objective = "binary:logistic",
                eval_metric = "auc",
                early_stopping_rounds = 100)
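
# A cross-validated way to pick nrounds instead of the fixed 500 (a sketch;
# the 5-fold setting and the reuse of xgb_params are assumptions):
cv = xgb.cv(data = dtrain,
            params = xgb_params,
            nrounds = 500,
            nfold = 5,
            objective = "binary:logistic",
            eval_metric = "auc",
            early_stopping_rounds = 100,
            verbose = FALSE)
# cv$best_iteration would then be the nrounds to retrain with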


### Results ###
# Extract predictions and calc stats
pred_prob = predict(bst, dval)
pred_class = ifelse(pred_prob > 0.5, 1, 0)
val_class = outcome[idx_val]
xgb_imp = xgb.importance(colnames(dat), bst)

# Rates conditioned on the actual class: TPR/FNR on actual positives,
# TNR/FPR on actual negatives
stats = data.frame(TPR = sum(pred_class == 1 & val_class == 1)/sum(val_class == 1),
                   FPR = sum(pred_class == 1 & val_class == 0)/sum(val_class == 0),
                   TNR = sum(pred_class == 0 & val_class == 0)/sum(val_class == 0),
                   FNR = sum(pred_class == 0 & val_class == 1)/sum(val_class == 1))
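
# The same counts can be read off a base-R confusion matrix, and the validation
# AUC reported by the watchlist can be recomputed from the predictions with the
# rank (Mann-Whitney) formula; conf_mat and auc_val are illustrative names:
conf_mat = table(predicted = pred_class, actual = val_class)
n_pos = sum(val_class == 1)
n_neg = sum(val_class == 0)
auc_val = (sum(rank(pred_prob)[val_class == 1]) - n_pos * (n_pos + 1)/2) / (n_pos * n_neg)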

par(mfrow = c(2, 1))
txt_stats = paste(names(stats), round(stats, 3), sep = " = ")
viopoints(pred_prob ~ val_class, pch = 21, bg = c("Red", "Blue"),
          xlab = "Actual Class", ylab = "Pr(Class 1)",
          main = c(paste(txt_stats[1:2], collapse = " "),
                   paste(txt_stats[3:4], collapse = " ")))
abline(h = 0.5, lty = 2, lwd = 2); grid()

xgb.plot.importance(xgb_imp, main = "Importance", xlab = "Gain")
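
# To look at individual trees rather than aggregate importance (a sketch;
# plotting the first tree is an arbitrary choice, and xgb.plot.tree needs the
# DiagrammeR package installed):
xgb.plot.tree(model = bst, trees = 0)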