Advertisement
Guest User

Untitled

a guest
Dec 7th, 2019
102
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.96 KB | None | 0 0
  1. import org.apache.spark.sql.{DataFrame, SparkSession}
  2. import org.apache.spark.sql.functions.{array, explode, udf, to_json, struct}
  3.  
  /**
   * Loads the raw feature file, reshapes it to one row per feature value and
   * joins it against the test and train id files on `cuid`.
   */
  object Classification {

    /** Session shared by every method of this object; runs with a local master. */
    val spark: SparkSession = SparkSession.builder().appName("Classifier")
      .config("spark.driver.maxResultSize", "2g")
      .config("spark.master", "local").getOrCreate()

    /**
     * Builds the joined test and train frames and prints their schemas.
     *
     * The session is stopped in a `finally` block so that a failure while
     * loading or joining does not leave it running.
     *
     * @param args command-line arguments (not used)
     */
    def main(args: Array[String]): Unit = {
      val testPath = "./mlboot_test.tsv" // 6MB
      val trainPath = "./mlboot_train_answers.tsv" // 15 MB

      try {
        val dataDF = loadDF()

        // Nothing consumes these frames yet; joinDF prints each schema as a side effect.
        val testDf = joinDF(testPath, dataDF)
        val trainDf = joinDF(trainPath, dataDF)
      } finally {
        spark.stop()
      }
    }

    /**
     * Reads the header-less, tab-separated data file with an explicit schema and
     * folds `feature_1`, `feature_2` and `feature_3` into a single `features`
     * column, producing one output row per feature value.
     *
     * @return frame with columns `cuid`, `cat_feature`, `dt_diff` and `features`
     */
    def loadDF(): DataFrame = {
      val dataPath = "./mlboot_data.tsv" // 11 GB

      import spark.implicits._
      import org.apache.spark.sql.types._

      val schema = StructType(Array(
        StructField("cuid", StringType, nullable = true),
        StructField("cat_feature", IntegerType, nullable = true),
        StructField("feature_1", StringType, nullable = true),
        StructField("feature_2", StringType, nullable = true),
        StructField("feature_3", StringType, nullable = true),
        StructField("dt_diff", LongType, nullable = true)
      ))

      // The file is read as raw lines first so the partition count (500) can be requested
      // explicitly. `csv(...)` already selects the CSV format, so no separate
      // `format("csv")` call is needed.
      val rawLines = spark.sparkContext.textFile(dataPath, 500).toDS()
      val tempDataDF = spark.read
        .option("header", "false")
        .option("delimiter", "\t")
        .schema(schema)
        .csv(rawLines)

      // Pack the three feature columns into an array, then explode it so that every
      // input row yields one row per feature value.
      tempDataDF
        .withColumn("features", array(tempDataDF("feature_1"), tempDataDF("feature_2"), tempDataDF("feature_3")))
        .withColumn("features", explode($"features"))
        .drop("feature_1", "feature_2", "feature_3")
    }

    /**
     * Reads a tab-separated file with a header row and inner-joins it with
     * `dataFrame` on the `cuid` column. The schema of the joined frame is
     * printed to standard output.
     *
     * @param path      location of the tab-separated file to read
     * @param dataFrame frame to join against; must contain a `cuid` column
     * @return the joined frame
     */
    def joinDF(path: String,
               dataFrame: DataFrame): DataFrame = {

      val df = spark.read.format("csv")
        .option("header", "true")
        .option("delimiter", "\t")
        .load(path)
        .join(dataFrame, Seq("cuid"), "inner")

      df.printSchema()

      df
    }
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement