Advertisement
Guest User

Untitled

a guest
Aug 22nd, 2017
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.40 KB | None | 0 0
  1. val rawData = sc.textFile("file:///search/dje/spark-ml/train_noheader.tsv")  // read the training file as lines of text
  2. val records = rawData.map(_.split("\t"))  // break each line into its tab-separated columns
  3. records.first()  // inspect the first parsed row
  4.  
  5. import org.apache.spark.mllib.regression.LabeledPoint
  6. import org.apache.spark.mllib.linalg.Vectors
  7. val data = records.map { fields =>  // build one LabeledPoint per parsed row
  8. val cleaned = fields.map(_.replaceAll("\"",""))  // remove every double-quote character from each column
  9. val target = cleaned(cleaned.length - 1).toInt  // the last column is parsed as the integer label
  10. val featureValues = cleaned.slice(4, cleaned.length - 1).map { case "?" => 0.0; case v => v.toDouble }  // columns 4 up to (not including) the label; "?" becomes 0.0
  11. LabeledPoint(target, Vectors.dense(featureValues))
  12. }
  13.  
  14. data.cache()  // mark the parsed RDD for caching; it is reused by the count and the trainers below
  15. val numData = data.count()  // number of labeled points
  16.  
  17. val nbData = records.map { fields =>  // same parsing as `data`, but negative feature values are clamped to 0.0
  18. val cleaned = fields.map(_.replaceAll("\"",""))  // remove every double-quote character from each column
  19. val target = cleaned(cleaned.length - 1).toInt  // the last column is parsed as the integer label
  20. val featureValues = cleaned.slice(4, cleaned.length - 1).map { raw => val v = if (raw == "?") 0.0 else raw.toDouble; if (v < 0) 0.0 else v }  // "?" becomes 0.0, then negatives become 0.0
  21. LabeledPoint(target, Vectors.dense(featureValues))
  22. }
  23.  
  24. import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
  25. import org.apache.spark.mllib.classification.SVMWithSGD
  26. import org.apache.spark.mllib.classification.NaiveBayes
  27. import org.apache.spark.mllib.tree.DecisionTree
  28. import org.apache.spark.mllib.tree.configuration.Algo
  29. import org.apache.spark.mllib.tree.impurity.Entropy
  30.  
  31. val numIterations = 10  // iteration count passed to the LogisticRegressionWithSGD and SVMWithSGD trainers
  32. val maxTreeDepth = 5  // depth limit passed to DecisionTree.train
  33.  
  34. val lrModel = LogisticRegressionWithSGD.train(data, numIterations)  // logistic regression on the cached `data` RDD
  35. val svmModel = SVMWithSGD.train(data, numIterations)  // linear SVM on the same input
  36. val nbMobel = NaiveBayes.train(nbData)  // uses `nbData` (negatives clamped to 0.0); was the undefined `nbdata`. Name `nbMobel` (sic) kept for later references
  37. val dtModel = DecisionTree.train(data, Algo.Classification, Entropy, maxTreeDepth)  // classification tree using entropy impurity
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement