Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.{LinearSVC, OneVsRest}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.Normalizer
import org.apache.spark.ml.linalg.Vectors

// Part 3: train and evaluate a one-vs-rest linear SVM on the chi-squared-selected features.
//
// NOTE(review): `model` (the fitted feature-selection pipeline) and `dfNew` (the input
// DataFrame) are not defined in this snippet; they are assumed to come from an earlier
// part of the session -- confirm before running this on its own.

// Single source of truth for the model location. Previously the model was saved to
// "/tmp/..." but loaded from the relative path "tmp/...", so the load only worked when
// the working directory happened to be "/".
val modelPath = "/tmp/chisqSelector-model_00"

// Persist the fitted pipeline (not strictly needed right now, but cheap to keep) and
// read it back so the rest of the script runs against the persisted model.
model.write.overwrite().save(modelPath)
val part2Model = PipelineModel.load(modelPath)

// Apply the persisted pipeline and keep only the columns needed downstream.
val dfModel = part2Model
  .transform(dfNew)
  .select("categoryIndex", "category", "selectedFeatures")

// Scale every feature vector to unit L2 norm.
val normalizer = new Normalizer()
  .setInputCol("selectedFeatures")
  .setOutputCol("normFeatures")
  .setP(2.0)
val normalizedDf = normalizer.transform(dfModel)

// 80/20 train/test split; the fixed seed keeps the split reproducible.
val Array(train, test) = normalizedDf.randomSplit(Array(0.8, 0.2), seed = 32)

// Binary base classifier that OneVsRest trains once per class.
val lsvc = new LinearSVC()
  .setFeaturesCol("normFeatures")
  .setLabelCol("categoryIndex")
  .setMaxIter(10)
  .setRegParam(0.1)

// Alternative (not used): chain everything in one pipeline, e.g.
//   new Pipeline().setStages(Array(part2Model, normalizer, ovr))
val ovr = new OneVsRest()
  .setClassifier(lsvc)
  .setFeaturesCol("normFeatures")
  .setLabelCol("categoryIndex")
val ovrModel = ovr.fit(train)

val predictions = ovrModel.transform(test)

// F1 is the evaluator's default metric; it is set explicitly here to make the intent clear.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("categoryIndex")
  .setMetricName("f1")
val f1Score = evaluator.evaluate(predictions)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement