/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// scalastyle:off println
package org.apache.spark.examples

import java.io.PrintWriter

import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.util.MLUtils

/*
Usage: SVMWSGD [options] <input file>
Options: --nr_epochs, --stepsize, --regularizer, --batch_size,
--validation_data, --result_file.
Only the input file is required; --validation_data is used only when
predictions and the accuracy of the model need to be computed.
Convert and merge the X.csv and Y.csv (flame csv) files into XYSpark.txt (a LIBSVM file)
by using the csv2libsvm.py script.
For validation, convert and merge XV.csv and YV.csv into XYVSpark.txt with the same script.
*/
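
// For reference, each line of a LIBSVM-format file looks like
// "<label> <index1>:<value1> <index2>:<value2> ..." with 1-based, increasing
// feature indices, e.g. (illustrative values, not taken from the real data files):
//
//   1.0 1:0.708 3:-0.213 7:1.0
//   0.0 2:0.556 5:0.091
//
// This is the format csv2libsvm.py is assumed to emit and the format that
// MLUtils.loadLibSVMFile below expects.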

object SVMWSGD {

  case class Params(
      input: String = null,
      nr_epochs: Int = 100,
      stepsize: Double = 1.0,
      batch_size: Int = 65536,
      regularizer: Double = 0.1,
      validation_data: String = null,
      result_file: String = null) extends AbstractParams[Params]
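
  // AbstractParams (from the Spark examples' mllib utilities) supplies a
  // reflection-based toString, so parsed Params instances print readably.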

  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setAppName("SVMWSGD")
    val sc = new SparkContext(conf)
    val defaultParams = Params()

    val parser = new OptionParser[Params]("SVMWSGD") {
      head("SVMWSGD: an example app for SVM classification with SGD.")
      opt[Int]("nr_epochs")
        .text(s"number of epochs, default: ${defaultParams.nr_epochs}")
        .action((x, c) => c.copy(nr_epochs = x))
      opt[Double]("stepsize")
        .text(s"initial step size, default: ${defaultParams.stepsize}")
        .action((x, c) => c.copy(stepsize = x))
      opt[Double]("regularizer")
        .text(s"regularization parameter, default: ${defaultParams.regularizer}")
        .action((x, c) => c.copy(regularizer = x))
      opt[Int]("batch_size")
        .text(s"batch size, default: ${defaultParams.batch_size}")
        .action((x, c) => c.copy(batch_size = x))
      arg[String]("<input>")
        .required()
        .text("input paths to labeled examples in LIBSVM format")
        .action((x, c) => c.copy(input = x))
      opt[String]("validation_data")
        .text("validation data paths to labeled examples in LIBSVM format," +
          s" default: ${defaultParams.validation_data}")
        .action((x, c) => c.copy(validation_data = x))
      opt[String]("result_file")
        .text("path to the file where the results should be stored in fcsv format," +
          s" default: ${defaultParams.result_file}")
        .action((x, c) => c.copy(result_file = x))
      note(
        """
          |For example, the following command runs this app on a sample dataset:
          |
          | bin/spark-submit --class org.apache.spark.examples.SVMWSGD \
          |   examples/target/scala-*/spark-examples-*.jar \
          |   data/mllib/sample_libsvm_data.txt
        """.stripMargin)
    }
    parser.parse(args, defaultParams).map { params =>
      run(params, sc)
    }.getOrElse {
      sys.exit(1)
    }
  }

  def run(params: Params, sc: SparkContext): Unit = {

    // $example on$
    // Load training data in LIBSVM format.
    val training = MLUtils.loadLibSVMFile(sc, params.input).cache()
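    // Caching matters here: SVMWithSGD makes one pass over (a sampled subset of)
    // this RDD per iteration, so recomputing it from the input file each time
    // would dominate the runtime.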
    // Split data into training (60%) and test (40%).
    // val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
    // val training = splits(0).cache()
    // val test = splits(1)
    val numSamples = training.count().toInt
    val numIterations = (numSamples / params.batch_size) * params.nr_epochs
    val miniBatchFraction = params.batch_size.toDouble / numSamples
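    // Worked example with assumed numbers: for 1,000,000 training samples,
    // batch_size = 65536 and nr_epochs = 100, this gives
    // numIterations = (1000000 / 65536) * 100 = 15 * 100 = 1500 (integer division)
    // and miniBatchFraction = 65536.0 / 1000000 ≈ 0.0655, i.e. each SGD
    // iteration samples roughly one batch worth of the data.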

    val start_time = System.nanoTime
    // Run the training algorithm to build the model.
    // Arguments: input, numIterations, stepSize, regParam, miniBatchFraction.
    val model = SVMWithSGD.train(training, numIterations, params.stepsize,
      params.regularizer, miniBatchFraction)

    // Clear the default threshold.
    model.clearThreshold()
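    // With the threshold cleared, model.predict returns the raw SVM margin
    // rather than a 0/1 label; BinaryClassificationMetrics below needs these
    // raw scores to sweep thresholds for the ROC curve.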

    val passed_time = (System.nanoTime - start_time) / 1000000
    println(s"Finished SVM with SGD training in $passed_time milliseconds.")

    // Print to console.
    println("bias/intercept: " + model.intercept)
    println("weights: " + model.weights)

    // Print to file: the number of weights on the first line, then the weights.
    if (params.result_file != null) {
      new PrintWriter(s"data/mllib/${params.result_file}") {
        write(model.weights.size + "\n"
          + model.weights.toString().replace("[", "").replace("]", ""))
        close()
      }
    }

    // Compute raw scores on the test set.
    // Load the validation data.
    if (params.validation_data != null) {
      println("validating and computing scores")

      val test = MLUtils.loadLibSVMFile(sc, params.validation_data).cache()

      val scoreAndLabels = test.map { point =>
        val score = model.predict(point.features)
        (score, point.label)
      }

      // The threshold was cleared above, so turn the raw margins into 0/1
      // labels by thresholding at 0 before measuring accuracy.
      val predicted_labels =
        scoreAndLabels.map { case (score, _) => if (score > 0.0) 1.0 else 0.0 }.collect()
      println("predicted labels:")
      predicted_labels.foreach(println)
      val actual_labels = test.map { point => point.label }.collect()
      println("actual labels:")
      actual_labels.foreach(println)

      val accuracy = actual_labels.zip(predicted_labels).map {
        case (x, y) => if (x == y) 1 else 0
      }

      println("correct values = " + accuracy.count(_ == 1) +
        ", total length = " + accuracy.length)
      println("accuracy = " + accuracy.count(_ == 1).toDouble / accuracy.length)

      // Get evaluation metrics.
      val metrics = new BinaryClassificationMetrics(scoreAndLabels)
      val auROC = metrics.areaUnderROC()
      println("Area under ROC = " + auROC)
    }

    // Save the model, then load it back to demonstrate round-tripping.
    model.save(sc, "target/tmp/scalaSVMWithSGDModel")
    val sameModel = SVMModel.load(sc, "target/tmp/scalaSVMWithSGDModel")
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println
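
// Illustrative full invocation (the jar path and data file names are
// placeholders, following the usage comment at the top of this file):
//
//   bin/spark-submit --class org.apache.spark.examples.SVMWSGD \
//     examples/target/scala-*/spark-examples-*.jar \
//     --nr_epochs 100 --stepsize 1.0 --regularizer 0.1 --batch_size 65536 \
//     --validation_data data/mllib/XYVSpark.txt --result_file results.csv \
//     data/mllib/XYSpark.txt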