Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import com.microsoft.ml.spark.LightGBMClassificationModel
- import org.apache.spark.ml.classification.RandomForestClassificationModel
- def getFeatureImportances(inputContainer: PipelineModelContainer): (String, String) = {
- val transformer = inputContainer.pipelineModel.stages.last
- val featureImportancesVector = inputContainer.params match {
- case RandomForestParameters(numTrees, treeDepth, featureTransformer) => transformer.asInstanceOf[RandomForestClassificationModel].featureImportances
- case LightGBMParameters(treeDepth, numLeaves, iterations, featureTransformer) => transformer.asInstanceOf[LightGBMClassificationModel].getFeatureImportances("split")
- }
- val colNames = inputContainer.featureColNames
- val sortedFeatures = (colNames zip featureImportancesVector.toArray).sortWith(_._2 > _._2).zipWithIndex
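- Compile error: value toArray is not a member of java.io.Serializable (the two match branches return different types, a Spark ML Vector from RandomForestClassificationModel.featureImportances and an Array[Double] from LightGBMClassificationModel.getFeatureImportances, so featureImportancesVector is inferred as their common supertype java.io.Serializable; converting the RandomForest branch with .toArray makes both branches Array[Double])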
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement