Advertisement
Dundre32

Part3 - 00

Apr 26th, 2020
819
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Scala 1.51 KB | None | 0 0
  // Part 3: train and evaluate a one-vs-rest linear SVM on the features
  // selected by the Part 2 pipeline model.
  //
  // Expects two values to be in scope already (defined outside this script):
  //   model - the fitted model from Part 2 (presumably a PipelineModel, since
  //           it is reloaded through PipelineModel.load -- confirm)
  //   dfNew - the DataFrame the Part 2 model is applied to
  import org.apache.spark.ml.PipelineModel
  import org.apache.spark.ml.feature.Normalizer
  import org.apache.spark.ml.classification.LinearSVC
  import org.apache.spark.ml.classification.{LinearSVC, OneVsRest}
  import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
  import org.apache.spark.ml.linalg.Vectors

  // Single source of truth for the model location. The original saved to
  // "/tmp/..." but loaded from the relative "tmp/...", which does not point
  // at the same place.
  val modelPath = "/tmp/chisqSelector-model_00"

  // Persist the Part 2 model (not strictly needed here, but handy for reuse),
  // then reload it from the very same path.
  model.write.overwrite().save(modelPath)
  val part2Model = PipelineModel.load(modelPath)

  // Run the Part 2 model over the data and keep only the columns used below.
  val dfModel = part2Model
    .transform(dfNew)
    .select("categoryIndex", "category", "selectedFeatures")

  // Scale every feature vector to unit L2 norm (p = 2.0).
  val normalizer = new Normalizer()
    .setInputCol("selectedFeatures")
    .setOutputCol("normFeatures")
    .setP(2.0)
  val normalizedDf = normalizer.transform(dfModel)

  // 80/20 train/test split; the fixed seed keeps the split repeatable.
  val Array(train, test) = normalizedDf.randomSplit(Array(0.8, 0.2), seed = 32)

  // Binary linear SVM used as the base learner of the one-vs-rest reduction.
  val lsvc = new LinearSVC()
    .setFeaturesCol("normFeatures")
    .setLabelCol("categoryIndex")
    .setMaxIter(10)
    .setRegParam(0.1)

  // One-vs-rest wrapper; the feature/label columns are set on the wrapper
  // as well as on the base classifier, exactly as in the original script.
  val ovr = new OneVsRest().setClassifier(lsvc)
  val ovrModel = ovr
    .setFeaturesCol("normFeatures")
    .setLabelCol("categoryIndex")
    .fit(train)

  // Score the held-out split.
  val predictions = ovrModel.transform(test)

  // F1 is the evaluator's default metric; it is set explicitly here so the
  // reported number is unambiguous.
  val evaluator = new MulticlassClassificationEvaluator()
    .setLabelCol("categoryIndex")
    .setMetricName("f1")

  val f1Score = evaluator.evaluate(predictions)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement