

a guest, Aug 19th, 2019
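import com.microsoft.ml.spark.LightGBMClassificationModel
import org.apache.spark.ml.classification.RandomForestClassificationModel

def getFeatureImportances(inputContainer: PipelineModelContainer): (String, String) = {

    // The fitted classifier is the last stage of the pipeline.
    val transformer = inputContainer.pipelineModel.stages.last

    // Both branches must return the same type. RandomForestClassificationModel.featureImportances
    // is an org.apache.spark.ml.linalg.Vector, while LightGBMClassificationModel.getFeatureImportances
    // returns an Array[Double]. Without the .toArray on the Vector, Scala infers their common
    // supertype, java.io.Serializable, for the whole match, which is what produced the compile error:
    //   value toArray is not a member of java.io.Serializable
    val featureImportances: Array[Double] = inputContainer.params match {
        case RandomForestParameters(numTrees, treeDepth, featureTransformer) =>
            transformer.asInstanceOf[RandomForestClassificationModel].featureImportances.toArray
        case LightGBMParameters(treeDepth, numLeaves, iterations, featureTransformer) =>
            transformer.asInstanceOf[LightGBMClassificationModel].getFeatureImportances("split")
    }

    // Pair each feature column with its importance, sort by importance (highest first)
    // and attach a 0-based rank.
    val colNames = inputContainer.featureColNames
    val sortedFeatures = (colNames zip featureImportances).sortWith(_._2 > _._2).zipWithIndex

    // (the rest of the function, which builds the (String, String) result, is not included in the paste)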
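Once both match branches yield an Array[Double] (for the random forest branch, calling toArray on the featureImportances Vector does this), the ranking step no longer depends on Spark at all. The sketch below isolates that step so it can be tried without a cluster. It is a minimal illustration under those assumptions: the object name FeatureRanking, the method rankFeatures, and the sample column names and scores are hypothetical, not part of the original code or of any Spark/MMLSpark API.

```scala
object FeatureRanking {
  // Hypothetical helper: pairs each feature column name with its importance score,
  // sorts by score (highest first) and attaches a 0-based rank, mirroring the
  // `sortedFeatures` line of the paste.
  def rankFeatures(colNames: Seq[String], importances: Array[Double]): Seq[((String, Double), Int)] = {
    // zip silently truncates to the shorter input, so check the lengths explicitly.
    require(colNames.length == importances.length,
      s"expected one importance per column, got ${importances.length} for ${colNames.length} columns")
    colNames.zip(importances.toSeq).sortWith(_._2 > _._2).zipWithIndex
  }

  def main(args: Array[String]): Unit = {
    // Made-up scores for illustration, not output from a real model.
    val ranked = rankFeatures(Seq("age", "income", "tenure"), Array(0.2, 0.5, 0.3))
    // Prints "0: income (0.5)", then "1: tenure (0.3)", then "2: age (0.2)".
    ranked.foreach { case ((name, score), rank) => println(s"$rank: $name ($score)") }
  }
}
```

The length check is a design choice worth keeping in the real function too: if featureColNames and the model's importance vector ever disagree in size, zip would drop entries without any error.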