Advertisement
Guest User

Untitled

a guest
Mar 22nd, 2019
91
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 9.58 KB | None | 0 0
  1. # Copyright (c) Microsoft Corporation. All rights reserved.
  2.  
  3. # Third Party Programs. This software enables you to obtain software applications from other sources.
  4. # Those applications are offered and distributed by third parties under their own license terms.
  5. # Microsoft is not developing, distributing or licensing those applications to you, but instead,
  6. # as a convenience, enables you to use this software to obtain those applications directly from
  7. # the application providers.
  8. # By using the software, you acknowledge and agree that you are obtaining the applications directly
  9. # from the third party providers and under separate license terms, and that it is your responsibility to locate,
  10. # understand and comply with those license terms.
  11. # Microsoft grants you no license rights for third-party software or applications that is obtained using this software.
  12.  
  13.  
  14. ##PBI_R_VISUAL: VIZGAL_DTREE Graphical display of Decision Tree
  15. # Computes and visualizes a decision tree used for classification or piecewise regression
  16. #
  17. # INPUT:
  18. # The input dataset should include at least two columns. The first column is the dependent variable;
  19. # the remaining columns are the independent variables.
  20. # EXAMPLES:
  21. # #for R environment
  22. # dataset<-mtcars #assign dataset
  23. # source("visGal_corrplot.R") #create graphics
  24. #
  25. # WARNINGS:
  26. # This visual is intended to be used for classification tasks. It was not tested for regression trees.
  27. #
  28. # CREATION DATE: 06/01/2016
  29. #
  30. # LAST UPDATE: 08/09/2016
  31. #
  32. # VERSION: 0.0.1
  33. #
  34. # R VERSION TESTED: 3.2.2
  35. #
  36. # AUTHOR: B. Efraty (boefraty@microsoft.com)
  37. #
  38. # REFERENCES: https://cran.r-project.org/web/packages/corrplot/vignettes/corrplot-intro.html
  39.  
  #PBI_EXAMPLE_DATASET for debugging purposes
  # When no `dataset` object exists (e.g. the script is run outside Power BI),
  # fall back to iris with the label column (Species) moved to the front,
  # followed by the four numeric measurements as predictors.
  if (!exists("dataset")) {
    data(iris) # Sepal.Length, Sepal.Width, Petal.Length, Petal.Width, Species
    dataset <- iris[, c("Species", "Sepal.Length", "Sepal.Width",
                        "Petal.Length", "Petal.Width")]
  }
  46.  
  ############ User Parameters #########

  ##PBI_PARAM: Should warnings messages be displayed?
  #Type:logical, Default:TRUE, Range:NA, PossibleValues:NA, Remarks: NA
  # When TRUE, the fallback (no-tree) plot shows a warning text as subtitle.
  # NOTE(review): the annotation above says Default:TRUE but the value assigned
  # here is FALSE -- confirm which one is intended.
  showWarnings <- FALSE

  ##PBI_PARAM: the maximum depth of the final tree [1, 30]
  #Type:positive integer, Default:20, Range:[1, 30], PossibleValues:NA, Remarks: The tree of maxDepth is not promised
  # Upper bound passed to rpart.control(maxdepth = ...); the tree is pruned
  # afterwards, so the plotted tree may be shallower than this.
  maxDepth <- 20
  56.  
  ###############Library Declarations###############

  # Attach the package named by the string `packageName`. If it cannot be
  # loaded, emit a warning instead of stopping the script. Despite its name,
  # this helper does not install anything. `...` is accepted and ignored.
  libraryRequireInstall <- function(packageName, ...) {
    attached <- require(packageName, character.only = TRUE)
    if (!attached) {
      warning(paste0("*** The package: '", packageName,
                     "' was not installed ***"))
    }
  }

  libraryRequireInstall("rpart")
  libraryRequireInstall("rpart.plot")
  libraryRequireInstall("RColorBrewer")
  67.  
  ###### Inner parameters and definitions ###################

  ##PBI_PARAM: Should info text be displayed in subtitle?
  #Type:logical, Default:TRUE, Range:NA, PossibleValues:NA, Remarks: NA
  # Only classification trees get the error summary subtitle.
  showInfo <- TRUE

  ##PBI_PARAM: Complexity parameter.
  # Any split that does not decrease the overall lack of fit by a factor of complexity is not attempted.
  #Type:numeric, Default:1e-05, Range:[0, 1], PossibleValues:NA, Remarks: If complexity and xval are 0 tree is maximal
  # Used as rpart.control(cp = ...) when growing the initial, large tree.
  complexity <- 1e-05

  ##PBI_PARAM: the minimum number of observations in any terminal (leaf) node
  #Type:positive integer, Default:2, Range:[1, 100], PossibleValues:NA, Remarks: NA
  # Used as rpart.control(minbucket = ...).
  minBucket <- 2

  ##PBI_PARAM: indicator if xval parameter is to be found automatically
  #Type:bool, Default:TRUE, Range:NA, PossibleValues:NA, Remarks: NA
  # When TRUE, `xval` below is overwritten by autoXvalFunc(<number of rows>).
  autoXval <- TRUE

  ##PBI_PARAM: number of cross-validations, used only if autoXval = FALSE
  #Type:integer, Default:10, Range:[0, 1000], PossibleValues:NA, Remarks: Can not be larger than number of records
  # Used as rpart.control(xval = ...).
  xval <- 10

  ##PBI_PARAM: the random number generator (RNG) state for random number generation
  #Type: numeric, Default:42, Range:NA, PossibleValues:NA, Remarks: NA
  # Passed to set.seed() before the tree is fitted.
  randSeed <- 42

  ##PBI_PARAM: minimum required samples (rows in data table)
  #Type: positive integer, Default:10, Range:[5, 100], PossibleValues:NA, Remarks: NA
  # Fewer rows than this yields the "Wrong data dimensionality" fallback.
  minRows <- 10

  ##PBI_PARAM: maximum attempts to construct tree with optimal depth > 1
  #Type: positive integer, Default:10, Range:[1, 50], PossibleValues:NA, Remarks: NA
  # Upper bound on refits of the tree while the pruned tree has no split.
  maxNumAttempts <- 10
  102.  
  103. ###############Internal functions definitions#################
  104.  
  # Automatically select the number of cross-validation folds (rpart `xval`)
  # from the number of rows, using a fixed lookup table (defaults):
  #   [0, 5] -> 0, (5, 10] -> 2, (10, 100] -> 10, (100, 500] -> 100,
  #   (500, 1000] -> 10, (1000, 10000] -> 5, (10000, Inf] -> 2
  # NOTE(review): 100 folds for 101-500 rows is far higher than the
  # neighbouring bins -- confirm this is intended.
  #
  # numRows - number of rows in the data
  # breaks  - interval boundaries (increasing); the lowest one is included,
  #           so 0 rows maps to the first entry instead of NA
  # xvals   - fold count per interval; must have length(breaks) - 1 entries
  # Returns the selected fold count (NA only if numRows is outside `breaks`).
  autoXvalFunc <- function(numRows,
                           breaks = c(0, 5, 10, 100, 500, 1000, 10000, Inf),
                           xvals = c(0, 2, 10, 100, 10, 5, 2)) {
    stopifnot(length(xvals) == length(breaks) - 1L)
    bin <- cut(numRows, breaks = breaks, include.lowest = TRUE)
    xvals[as.integer(bin)]
  }
  112.  
  # Select the complexity parameter (CP) used to prune the tree (for optimal
  # tree pruning), from an rpart cptable converted to a data.frame (expected
  # columns: CP, "rel error" and, when cross-validation ran, xerror).
  # The first row whose error is within `delta` (as a fraction of the error
  # range) of the minimum error is chosen. If there is no xerror column
  # (e.g. xval = 0), the "rel error" column is used for the selection instead.
  #
  # Returns a one-row data.frame with columns:
  #   ind    - index of the selected cptable row
  #   CP     - CP value of that row
  #   xerror - its cross-validated error, or NA when cptable has no xerror
  #   relErr - its "rel error" value
  optimalCPbyXError <- function(cptable, delta = 0.00001) {
    opt <- data.frame(ind = NaN, CP = NaN, xerror = NaN)

    # Exact-name lookups: `cptable$rel` relied on partial matching of
    # "rel error", which is fragile and warns under
    # options(warnPartialMatchDollar = TRUE). A literal "rel" column is still
    # accepted for backward compatibility.
    relErr <- cptable[["rel error"]]
    if (is.null(relErr)) {
      relErr <- cptable[["rel"]]
    }
    hasXerror <- !is.null(cptable[["xerror"]])
    xerror <- if (hasXerror) cptable[["xerror"]] else relErr

    thresh <- min(xerror) + (max(xerror) - min(xerror)) * delta
    opt$ind <- which(xerror <= thresh)[1]
    opt$CP <- cptable[["CP"]][opt$ind]
    opt$xerror <- if (hasXerror) xerror[opt$ind] else NA
    opt$relErr <- relErr[opt$ind]
    opt
  }
  129.  
  # Round numeric `x` to `p` decimal places and format it as character with
  # exactly `p` digits after the decimal point (e.g. 2 -> "2.00").
  # Non-numeric input yields an invisible NULL.
  d2form <- function(x, p = 2) {
    if (!is.numeric(x)) {
      return(invisible(NULL))
    }
    rounded <- round(x, p)
    format(rounded, nsmall = p)
  }
  132.  
  # Automatically convert columns with few distinct values to factors:
  # every logical column, and every column with at most `minCount` distinct
  # values (NA counts as a value), is replaced by as.factor(column).
  # Other columns are left untouched. Returns the modified data.frame.
  # seq_len() makes a data.frame with zero columns a no-op (1:ncol would
  # iterate over c(1, 0) and fail).
  convertCol2factors <- function(data, minCount = 3) {
    for (j in seq_len(ncol(data))) {
      column <- data[[j]]
      if (is.logical(column) || length(unique(column)) <= minCount) {
        data[[j]] <- as.factor(column)
      }
    }
    data
  }
  146.  
  # Compute the root node error: the error rate of always predicting the most
  # frequent label, i.e. 1 - (count of the most frequent label) / (number of labels).
  # Missing labels are ignored; returns NA_real_ if no labels remain
  # (previously NA labels gave NA and empty input gave Inf with a warning).
  rootNodeError <- function(labels) {
    labels <- labels[!is.na(labels)]
    if (length(labels) == 0L) {
      return(NA_real_)
    }
    # Count occurrences of each distinct label with exact matching
    # (same equality semantics as the former `labels == u` loop).
    counts <- tabulate(match(labels, unique(labels)))
    1 - max(counts) / length(labels)
  }
  155.  
  # This function is almost identical to fancyRpartPlot{rattle}.
  # It is duplicated here because the call for library(rattle) may trigger a
  # GTK load, which may be missing on the user's machine.
  #
  # Draws `model` with rpart.plot::prp using colour-coded node boxes, then adds
  # a title and subtitle.
  #   model    - a fitted rpart model
  #   main/sub - title and subtitle passed to title()
  #   palettes - optional RColorBrewer palette names; class k is coloured from
  #              palette ((k - 1) %% 6) + 1. Missing entries are filled from the
  #              defaults below.
  #   ...      - extra arguments forwarded to rpart.plot::prp
  replaceFancyRpartPlot <- function(model, main = "", sub = "", palettes, ...) {
    numpals <- 6
    palsize <- 5
    num.classes <- length(attr(model, "ylevels"))

    default.palettes <- c("Greens", "Blues", "Oranges", "Purples",
                          "Reds", "Greys")
    if (missing(palettes)) {
      palettes <- default.palettes
    }
    # seq_along (not seq(length(...))) so an empty `palettes` is fully filled.
    missed <- setdiff(seq_len(numpals), seq_along(palettes))
    palettes <- c(palettes, default.palettes[missed])

    # The first `palsize` (lightest) shades of each of the first `numpals`
    # palettes, concatenated: palette k occupies positions
    # (k - 1) * palsize + 1 ... k * palsize.
    pals <- unlist(lapply(palettes[seq_len(numpals)], function(p) {
      RColorBrewer::brewer.pal(9, p)[seq_len(palsize)]
    }))

    isClass <- model$method == "class"
    if (isClass) {
      # Drop columns 2..(num.classes + 1) of yval2; then x[1 + x[1]] selects,
      # per node, the value indexed by the fitted class in column 1.
      # NOTE(review): assumes rpart's yval2 layout (fitted class, class counts,
      # class probabilities, node probability), so this is the fitted-class
      # probability. drop = FALSE keeps a one-row frame as a matrix.
      yval2 <- model$frame$yval2[, -(seq_len(num.classes) + 1), drop = FALSE]
      per <- apply(yval2, 1, function(x) x[1 + x[1]])
    } else {
      per <- model$frame$yval / max(model$frame$yval)
    }
    per <- as.numeric(per)

    if (isClass) {
      # Shade 1..palsize inside the fitted class's palette block, blocks
      # wrapping every `numpals` classes. Computed 0-based and then shifted by 1
      # so the index is always in 1..numpals*palsize; the former
      # `(...) %% (numpals * palsize)` produced 0 (no colour, misaligned
      # box.col) for pure nodes of the 6th, 12th, ... class.
      shade <- trunc(pmin(1 + (per * palsize), palsize))
      col.index <- (palsize * (model$frame$yval - 1) + shade - 1) %%
        (numpals * palsize) + 1
    } else {
      # NOTE(review): for negative responses this can still evaluate to 0
      # (regression trees are untested per the script header).
      col.index <- round(per * (palsize - 1)) + 1
    }
    col.index <- abs(col.index)

    extra <- if (isClass) 104 else 101
    rpart.plot::prp(model, type = 2, extra = extra, box.col = pals[col.index],
                    nn = TRUE, varlen = 0, faclen = 0, shadow.col = "grey",
                    fallen.leaves = TRUE, branch.lty = 3, ...)
    title(main = main, sub = sub)
  }
  200.  
  201.  
  202.  
  203.  
  ###############Upfront input correctness validations (where possible)#################

  pbiWarning <- ""
  pbiInfo <- ""

  # Remove rows with corrupted (missing) labels in the first column.
  # drop = FALSE keeps a single-column input as a data.frame, so it reaches the
  # dimensionality check below instead of crashing in convertCol2factors().
  dataset <- dataset[complete.cases(dataset[, 1]), , drop = FALSE]
  dataset <- convertCol2factors(dataset)
  nr <- nrow(dataset)
  nc <- ncol(dataset)
  nl <- length(unique(dataset[, 1])) # number of distinct label values

  # Usable only with enough rows, at least one predictor column and at least
  # two distinct label values.
  goodDim <- (nr >= minRows && nc >= 2 && nl >= 2)
  216.  
  217.  
  ##############Main Visualization script###########
  set.seed(randSeed)
  opt <- NULL
  dtree <- NULL

  if (autoXval) {
    xval <- autoXvalFunc(nr)
  }

  dNames <- names(dataset)
  # Backticks allow a label column name that is not a syntactic R name.
  form <- as.formula(paste0("`", dNames[1], "`", " ~ ."))

  # Run the model: grow a large tree, then prune it at the CP picked from its
  # cptable. Retry (at least once, at most maxNumAttempts times) until the
  # pruned tree has at least one split. seq_len(max(1, ...)) avoids 1:0
  # running twice when maxNumAttempts is 0.
  if (goodDim) {
    rooNodeErr <- rootNodeError(dataset[, 1]) # same for every attempt
    for (attempt in seq_len(max(1, maxNumAttempts))) {
      dtree <- rpart(form, dataset,
                     control = rpart.control(minbucket = minBucket,
                                             cp = complexity,
                                             maxdepth = maxDepth,
                                             xval = xval)) # large tree
      opt <- optimalCPbyXError(as.data.frame(dtree$cptable))
      dtree <- prune(dtree, cp = opt$CP)
      if (opt$ind > 1) {
        break
      }
    }
  }

  # Info for classifier: the selected relative errors are multiplied by the
  # root node error before being displayed.
  if (showInfo && !is.null(dtree) && dtree$method == "class") {
    pbiInfo <- paste0("Rel error = ", d2form(opt$relErr * rooNodeErr),
                      "; CVal error = ", d2form(opt$xerror * rooNodeErr),
                      "; Root error = ", d2form(rooNodeErr),
                      ";cp = ", d2form(opt$CP, 3))
  }

  treeFound <- goodDim && !is.null(opt) && isTRUE(opt$ind > 1)
  if (treeFound) {
    # Local stand-in for rattle::fancyRpartPlot (avoids loading rattle).
    replaceFancyRpartPlot(dtree, sub = pbiInfo)
  } else {
    if (showWarnings) {
      pbiWarning <- if (goodDim) {
        paste0("The tree depth is zero. Root error = ", d2form(rooNodeErr))
      } else {
        "Wrong data dimensionality"
      }
    }
    plot.new()
    title(main = NULL, sub = pbiWarning, outer = FALSE, col.sub = "gray40")
  }
  # Drop the input data from the workspace once the plot is drawn.
  remove("dataset")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement