Advertisement
Guest User

Untitled

a guest
Aug 19th, 2019
114
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.89 KB | None | 0 0
  1. import com.microsoft.ml.spark.LightGBMClassificationModel
  2. import org.apache.spark.ml.classification.RandomForestClassificationModel
  3.  
  4. def getFeatureImportances(inputContainer: PipelineModelContainer): (String, String) = {
  5.  
  6. val transformer = inputContainer.pipelineModel.stages.last
  7.  
  8. val featureImportancesVector = inputContainer.params match {
  9.  
  10. case RandomForestParameters(numTrees, treeDepth, featureTransformer) => transformer.asInstanceOf[RandomForestClassificationModel].featureImportances
  11. case LightGBMParameters(treeDepth, numLeaves, iterations, featureTransformer) => transformer.asInstanceOf[LightGBMClassificationModel].getFeatureImportances("split")
  12. }
  13.  
  14. val colNames = inputContainer.featureColNames
  15. val sortedFeatures = (colNames zip featureImportancesVector.toArray).sortWith(_._2 > _._2).zipWithIndex
  16.  
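  17. // Compiler error (reported for line 15): value toArray is not a member of java.io.Serializable. The two match branches appear to return different types (a Vector from RandomForestClassificationModel.featureImportances, an Array[Double] from LightGBMClassificationModel.getFeatureImportances), so the match result is inferred as their common supertype java.io.Serializable, which has no toArray. Appending .toArray to the RandomForest branch on line 10 should give both branches the type Array[Double].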
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement