Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- public static void main(String[] args) {
- SparkSession spark = SparkSession.builder()
- .appName("Simple Application")
- .config("spark.master", "local")
- .getOrCreate();
- Dataset<Row> dataFrame = getDataset();
- Dataset<Row>[] splits = dataFrame.randomSplit(new double[]{0.7, 0.3}, 1234L);
- Dataset<Row> train = splits[0];
- Dataset<Row> test = splits[1];
- System.out.println("Start learning...");
- NaiveBayes nb = new NaiveBayes();
- NaiveBayesModel model = nb.fit(train);
- Dataset<Row> predictions = model.transform(test);
- predictions.show();
- // predictions.show(Integer.MAX_VALUE / 2);
- MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
- .setLabelCol("label")
- .setPredictionCol("prediction")
- .setMetricName("accuracy");
- double accuracy = evaluator.evaluate(predictions);
- System.out.println("Test set accuracy = " + accuracy);
- spark.stop();
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement