Advertisement
Guest User

Untitled

a guest
Dec 6th, 2019
123
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.41 KB | None | 0 0
  1. import org.apache.spark.SparkContext
  2. import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext, SparkSession}
  3. import org.apache.spark.sql.functions.broadcast
  4.  
  5.  
  /** Entry point for the `Classifier` Spark job: joins the training answers with the raw
    * feature data on `cuid` and prints a sample of the joined rows.
    */
  object Classification {

    /** Names given, in order, to the header-less columns `_c0`..`_c5` of the data file. */
    private val DataColumns =
      Seq("cuid", "cat_feature", "feature_1", "feature_2", "feature_3", "dt_diff")

    /** Runs the job.
      *
      * @param args optional positional overrides for the input paths:
      *             `args(0)` data TSV (default `./mlboot_data.tsv`),
      *             `args(1)` test TSV (default `./mlboot_test.tsv`),
      *             `args(2)` training-answers TSV (default `./mlboot_train_answers.tsv`).
      */
    def main(args: Array[String]): Unit = {
      val dataPath = args.lift(0).getOrElse("./mlboot_data.tsv") // ~11 GB
      val testPath = args.lift(1).getOrElse("./mlboot_test.tsv") // ~6 MB
      val trainPath = args.lift(2).getOrElse("./mlboot_train_answers.tsv") // ~15 MB

      // NOTE(review): hard-coding the master overrides `spark-submit --master`; consider removing.
      val spark = SparkSession.builder()
        .appName("Classifier")
        .config("spark.master", "local")
        .getOrCreate()

      // stop() lives in `finally` so the session is released even if a read or the join fails.
      try {
        // The data file has no header, so Spark names its columns _c0, _c1, ...; rename them.
        // No repartition: in a broadcast hash join the large side is streamed as-is, so a
        // full shuffle of ~11 GB would only add cost.
        val dataDF = DataColumns.zipWithIndex.foldLeft(readTsv(spark, dataPath, header = false)) {
          case (df, (name, idx)) => df.withColumnRenamed(s"_c$idx", name)
        }

        // NOTE(review): testDF is not used yet.
        val testDF = readTsv(spark, testPath, header = true)

        // Broadcast the small (~15 MB) answers table, not the ~11 GB data set: the broadcast
        // side is gathered on the driver, which is presumably why the previous version needed
        // spark.driver.maxResultSize=11g and a 10-hour spark.sql.broadcastTimeout.
        // Keeping the answers on the left preserves the original output column order
        // (cuid, answer columns, data columns).
        val trainDF = broadcast(readTsv(spark, trainPath, header = true))
          .join(dataDF, Seq("cuid"), "inner")

        trainDF.show(6, truncate = false)
      } finally {
        spark.stop()
      }
    }

    /** Loads a tab-separated file through Spark's CSV source.
      *
      * @param spark  active session used for reading
      * @param path   file to load
      * @param header whether the first line holds column names (passed as the `header` option)
      */
    private def readTsv(spark: SparkSession, path: String, header: Boolean) =
      spark.read
        .format("csv")
        .option("header", header.toString)
        .option("delimiter", "\t")
        .load(path)
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement