Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
- // scalastyle:off println
- package org.apache.spark.examples
- import java.io.PrintWriter
- import scopt.OptionParser
- import org.apache.spark.{SparkConf, SparkContext}
- import org.apache.spark.examples.mllib.AbstractParams
- import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
- import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
- import org.apache.spark.mllib.util.MLUtils
/*
 Usage: SVMWSGD <input file> <epochs> <stepSize> <regularizer>
        <batch size> <validation data file>
 The first 5 arguments are required; the last argument is used only when
 there is a need to predict values and calculate the accuracy of the model.
 Convert and merge the X.csv and Y.csv (flame csv) files to XYSpark.txt (libsvm file)
 by using the csv2libsvm.py script.
 For validation, convert and merge XY.csv and YV.csv to XYVSpark.txt by using the script.
 */
object SVMWSGD {

  /**
   * Command-line parameters for this example.
   *
   * @param input           path to training data in LIBSVM format (required)
   * @param nr_epochs       number of passes over the training set
   * @param stepsize        initial SGD step size
   * @param batch_size      mini-batch size, used to derive numIterations and
   *                        miniBatchFraction for SVMWithSGD
   * @param regularizer     regularization parameter
   * @param validation_data optional path to validation data in LIBSVM format
   * @param result_file     optional file name (under data/mllib/) for the
   *                        learned weights
   */
  case class Params(
      input: String = null,
      nr_epochs: Int = 100,
      stepsize: Double = 1.0,
      batch_size: Int = 65536,
      regularizer: Double = 0.1,
      validation_data: String = null,
      result_file: String = null) extends AbstractParams[Params]

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("SVMWSGD")
    val sc = new SparkContext(conf)

    val defaultParams = Params()

    val parser = new OptionParser[Params]("SVMWSGD") {
      // Fixed: this example trains a linear *classifier*, not a regression model.
      head("SVM: an example app for linear classification with SGD.")
      opt[Int]("nr_epochs")
        .text(s"number of epochs, default ${defaultParams.nr_epochs}")
        .action((x, c) => c.copy(nr_epochs = x))
      opt[Double]("stepsize")
        .text(s"initial step size, default: ${defaultParams.stepsize}")
        .action((x, c) => c.copy(stepsize = x))
      opt[Double]("regularizer")
        .text(s"regularization parameter, default: ${defaultParams.regularizer}")
        .action((x, c) => c.copy(regularizer = x))
      opt[Int]("batch_size")
        .text(s"batch size, default: ${defaultParams.batch_size}")
        .action((x, c) => c.copy(batch_size = x))
      arg[String]("<input>")
        .required()
        .text("input paths to labeled examples in LIBSVM format")
        .action((x, c) => c.copy(input = x))
      opt[String]("validation_data")
        .text("validation data paths to labeled examples in LIBSVM format," +
          s" default: ${defaultParams.validation_data}")
        .action((x, c) => c.copy(validation_data = x))
      opt[String]("result_file")
        .text("path to file where the results need to be stored in fcsv format," +
          s" default: ${defaultParams.result_file}")
        .action((x, c) => c.copy(result_file = x))
      note(
        """
          |For example, the following command runs this app on a synthetic dataset:
          |
          | bin/spark-submit --class org.apache.spark.examples.SVMWSGD
          | examples/target/scala-*/spark-examples-*.jar
          | data/mllib/sample_linear_regression_data.txt
        """.stripMargin)
    }

    // Idiomatic: pattern match on the parse result instead of map/getOrElse.
    // Exit with a non-zero status when argument parsing fails.
    parser.parse(args, defaultParams) match {
      case Some(params) => run(params, sc)
      case None => sys.exit(1)
    }
  }

  /**
   * Trains an SVM with SGD on `params.input`, optionally evaluates it on
   * `params.validation_data` (accuracy and area under ROC), optionally dumps
   * the learned weights, and demonstrates model save/load.
   */
  def run(params: Params, sc: SparkContext): Unit = {
    // Load training data in LIBSVM format.
    val training = MLUtils.loadLibSVMFile(sc, params.input).cache()

    val numSamples = training.count().toInt
    // Fixed: guard against 0 iterations when the dataset is smaller than one
    // batch, and cap the mini-batch fraction at 1.0 in the same situation.
    val numIterations =
      math.max(1, (numSamples / params.batch_size) * params.nr_epochs)
    val miniBatchFraction =
      math.min(1.0, params.batch_size.toDouble / numSamples)

    val startTime = System.nanoTime
    // Run the training algorithm to build the model.
    // params: input, numIterations, stepSize, regParam, miniBatchFraction
    val model = SVMWithSGD.train(training, numIterations, params.stepsize,
      params.regularizer, miniBatchFraction)
    // Clear the default threshold so predict() returns raw margins, as
    // required by BinaryClassificationMetrics below.
    model.clearThreshold()
    val elapsedMillis = (System.nanoTime - startTime) / 1000000
    println(s"Finished SVM with SGD training in $elapsedMillis milliseconds.")

    // Print the learned model to the console.
    println("bias/intercept: " + model.intercept)
    println("weights: " + model.weights)

    // Optionally persist the weights: one line with the dimension, then the
    // comma-separated weight vector.
    if (params.result_file != null) {
      // Fixed: the original wrote the literal string "n" instead of a newline
      // ("\n"), and leaked the writer if write() threw.
      val writer = new PrintWriter(s"data/mllib/${params.result_file}")
      try {
        writer.write(model.weights.size + "\n"
          + model.weights.toString().replace("[", "").replace("]", ""))
      } finally {
        writer.close()
      }
    }

    // Optionally evaluate on held-out validation data.
    if (params.validation_data != null) {
      println("validating and computing scores")
      val test = MLUtils.loadLibSVMFile(sc, params.validation_data).cache()
      val scoreAndLabels = test.map { point =>
        (model.predict(point.features), point.label)
      }
      // Fixed: the original extracted _._2 (the *true label*) as the
      // "prediction", so it compared the labels against themselves and always
      // reported accuracy 1.0. Also, since the threshold was cleared above,
      // predict() returns a raw margin rather than a 0/1 label, so we classify
      // as positive when the margin is non-negative.
      val predictionAndLabel = scoreAndLabels.map { case (score, label) =>
        (if (score >= 0.0) 1.0 else 0.0, label)
      }.collect()
      val numCorrect = predictionAndLabel.count { case (pred, label) => pred == label }
      // Fixed: missing separator between the two counts in the output.
      println("correct values = " + numCorrect +
        " total length = " + predictionAndLabel.length)
      println("accuracy = " + numCorrect.toDouble / predictionAndLabel.length)

      // Threshold-independent evaluation on the raw scores.
      val metrics = new BinaryClassificationMetrics(scoreAndLabels)
      val auROC = metrics.areaUnderROC()
      println("Area under ROC = " + auROC)
    }

    // Demonstrate model persistence: save the model and load it back.
    model.save(sc, "target/tmp/scalaSVMWithSGDModel")
    val sameModel = SVMModel.load(sc, "target/tmp/scalaSVMWithSGDModel")

    sc.stop()
  }
}
- // scalastyle:on println
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement