import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.struct

object Classification {

  // Local SparkSession; driver.maxResultSize is raised to 2g so results
  // collected back to the driver from the 11 GB feature file do not hit
  // the default 1g limit.
  val spark: SparkSession = SparkSession.builder()
    .appName("Classifier")
    .config("spark.driver.maxResultSize", "2g")
    .config("spark.master", "local")
    .getOrCreate()

  def main(args: Array[String]): Unit = {

    val testPath = "./mlboot_test.tsv" // 6 MB
    val trainPath = "./mlboot_train_answers.tsv" // 15 MB

    val dataDF = loadDF()

    // joinDF returns the joined DataFrame; show() returns Unit, so call
    // it separately rather than assigning its result to a val.
    val testDf = joinDF(testPath, dataDF)
    val trainDf = joinDF(trainPath, dataDF)

    testDf.show(6, truncate = false)
    trainDf.show(6, truncate = false)

    spark.stop()
  }

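  /** Loads the raw 11 GB feature file with an explicit schema and packs
   *  the three raw feature columns into a single "features" struct. */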
  def loadDF(): DataFrame = {
    val dataPath = "./mlboot_data.tsv" // 11 GB

    import spark.implicits._
    import org.apache.spark.sql.types._

    val schema = StructType(Array(
      StructField("cuid", StringType, nullable = true),
      StructField("cat_feature", IntegerType, nullable = true),
      StructField("feature_1", StringType, nullable = true),
      StructField("feature_2", StringType, nullable = true),
      StructField("feature_3", StringType, nullable = true),
      StructField("dt_diff", LongType, nullable = true)
    ))

    // Read via textFile with at least 500 partitions so the 11 GB file is
    // split into reasonably sized tasks, then parse the lines as TSV.
    val tempDataDF = spark.read
      .option("header", "false")
      .option("delimiter", "\t")
      .schema(schema)
      .csv(spark.sparkContext.textFile(dataPath, minPartitions = 500).toDS())

    // Pack the three raw feature columns into a single struct column.
    val dataDF = tempDataDF
      .withColumn("features",
        struct(tempDataDF("feature_1"), tempDataDF("feature_2"), tempDataDF("feature_3")))
      .drop("feature_1", "feature_2", "feature_3")

    dataDF
  }

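  /** Loads a header-bearing TSV from `path` and inner-joins it with the
   *  feature DataFrame on the shared "cuid" user-id column. */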
  def joinDF(path: String, dataFrame: DataFrame): DataFrame = {

    val df = spark.read
      .option("header", "true")
      .option("delimiter", "\t")
      .csv(path)
      .join(dataFrame, Seq("cuid"), "inner")

    df
  }
}
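
// --- Hypothetical next step (not part of the original paste) ---
// The unused VectorAssembler import in the original paste hints that a
// feature-assembly step was planned. A minimal sketch under stated
// assumptions: the answers TSV contributes a "target" label column, and
// feature_1..feature_3 hold numeric strings that can be cast to double
// (if they are JSON blobs, a real pipeline would parse them first).
object TrainingSketch {
  import org.apache.spark.ml.classification.LogisticRegression
  import org.apache.spark.ml.feature.VectorAssembler
  import org.apache.spark.sql.DataFrame
  import org.apache.spark.sql.functions.col

  def train(trainDf: DataFrame): Unit = {
    // Pull the struct fields back out and cast everything to double.
    val numericDf = trainDf
      .withColumn("f1", col("features.feature_1").cast("double"))
      .withColumn("f2", col("features.feature_2").cast("double"))
      .withColumn("f3", col("features.feature_3").cast("double"))
      .withColumn("label", col("target").cast("double"))

    // Assemble the numeric columns into the single vector column that
    // Spark ML estimators expect; rows with nulls are skipped.
    val assembler = new VectorAssembler()
      .setInputCols(Array("f1", "f2", "f3", "cat_feature", "dt_diff"))
      .setOutputCol("mlFeatures")
      .setHandleInvalid("skip")

    val lr = new LogisticRegression()
      .setFeaturesCol("mlFeatures")
      .setLabelCol("label")

    val model = lr.fit(assembler.transform(numericDf))
    println(s"Trained a logistic regression over ${model.numFeatures} features")
  }
}
// Usage sketch: call TrainingSketch.train(trainDf) at the end of main,
// before spark.stop().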