import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{Binarizer, ChiSqSelector, VectorAssembler}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{array, avg, col, collect_list, explode, udf}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.collection._

object Classification {

  // Local-mode session; the memory settings below are sized for the ~10 GB feature file.
  val spark: SparkSession = SparkSession.builder().appName("Classifier")
    .config("spark.driver.maxResultSize", "5g")
    .config("spark.sql.shuffle.partitions", "5")
    .config("spark.driver.memory", "3g")
    .config("spark.executor.memory", "3g")
    .config("spark.memory.offHeap.size", "4g")
    .config("spark.memory.offHeap.enabled", "true")
    .config("spark.master", "local")
    .getOrCreate()

  def main(args: Array[String]): Unit = {

    val testPath = "./mlboot_test.tsv" // 6 MB
    val trainPath = "./mlboot_train_answers.tsv" // 15 MB

    val featuresDF = loadData()

    // The train answers file carries a "target" column, cast here to a Double "label".
    val testData = joinDF(testPath, featuresDF)
    val trainData = joinDF(trainPath, featuresDF)
      .withColumn("label", col("target").cast(DoubleType))
      .drop("target")

    testData.show(6, truncate = false)
    // trainData.show(6, truncate = false)

    classification(testData, trainData)

    spark.stop()
  }

  def joinDF(path: String, forJoinDF: DataFrame): DataFrame = {
    val tempDataDF = spark.read.format("csv")
      .option("header", "true")
      .option("delimiter", "\t")
      .load(path)

    val joinedDF = tempDataDF.join(forJoinDF, Seq("cuid"), "inner")
      .drop("cuid")

    joinedDF.printSchema()

    joinedDF
  }

  def loadData(): DataFrame = {

    import spark.implicits._

    val path = "./xaa.tsv" // 10 GB

    val schema = StructType(Array(
      StructField("cuid", StringType, nullable = true),
      StructField("cat_feat", DoubleType, nullable = true),
      StructField("feature_1", StringType, nullable = true),
      StructField("feature_2", StringType, nullable = true),
      StructField("feature_3", StringType, nullable = true),
      StructField("date_diff", DoubleType, nullable = true)))

    // Go through textFile first so the 10 GB file is read with 500 partitions.
    val dataDF = spark.read
      .option("header", "false")
      .option("delimiter", "\t")
      .schema(schema)
      .csv(spark.sparkContext.textFile(path, 500).toDS())

    // Combines the per-row feature maps of a cuid into one map, summing the
    // values of duplicate keys (CombineMaps is defined below).
    val combineMaps = new CombineMaps[Int, Double](IntegerType, DoubleType, _ + _)

    val df = dataDF
      .withColumn("features_j", array(dataDF("feature_1"), dataDF("feature_2"), dataDF("feature_3")))
      .withColumn("features_json", explode(col("features_j")))
      .rdd.map(row => {
        // Each feature column holds a JSON object of feature index -> value.
        import org.json4s._
        import org.json4s.jackson.JsonMethods._
        implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats
        val cuid: String = row.getAs[String]("cuid")
        val catFeat: Double = row.getAs[Double]("cat_feat")
        val jsonString: String = row.getAs[String]("features_json")
        val map: Map[Int, Double] = parse(jsonString).extract[Map[Int, Double]]
        val dateDiff: Double = row.getAs[Double]("date_diff")
        (cuid, catFeat, map, dateDiff)
      }).toDF("cuid", "cat_features", "map_features", "date_diff")
      .groupBy("cuid")
      .agg(
        collect_list("cat_features") as "cat_array",
        avg("date_diff") as "dt_diff",
        combineMaps(col("map_features")) as "combined_features")
      .withColumn("sparse_vector", mapToSparse(col("combined_features")))
      .withColumn("cat_vector", convertArrayToVector(col("cat_array")))
      .drop("combined_features")
      .drop("cat_array")

    df
  }

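  // A self-contained sketch of the json4s parsing step used in loadData above.
  // The payload here is a hypothetical example; the real feature files are not
  // included in this paste.
  def parseFeatureMapExample(): Map[Int, Double] = {
    import org.json4s._
    import org.json4s.jackson.JsonMethods._
    implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats
    parse("""{"3": 1.0, "17": 2.5}""").extract[Map[Int, Double]] // Map(3 -> 1.0, 17 -> 2.5)
  }
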
  def classification(testDF: DataFrame, trainDF: DataFrame): Unit = {

    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")
      .setMetricName("areaUnderROC")

    // Binarize the sparse features: anything above the threshold becomes 1.0.
    val binarizer = new Binarizer()
      .setInputCol("sparse_vector")
      .setOutputCol("binarized_vector_features")
      .setThreshold(4.0)

    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("dt_diff", "cat_vector", "binarized_vector_features"))
      .setOutputCol("rf_features")

    val chiSqSelector = new ChiSqSelector()
      .setLabelCol("label")
      .setFeaturesCol("rf_features")
      .setOutputCol("features")

    // Create the trainer and set its parameters.
    val randomForestClassifier = new RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")

    val paramGrid = new ParamGridBuilder()
      .addGrid(randomForestClassifier.maxBins, Array(15, 25, 35, 45))
      .addGrid(randomForestClassifier.maxDepth, Array(4, 6, 8, 10, 12, 14, 16))
      .addGrid(randomForestClassifier.numTrees, Array(12, 15, 18, 20))
      .addGrid(randomForestClassifier.impurity, Array("entropy", "gini"))
      .build()

    val pipeline = new Pipeline()
      .setStages(Array(binarizer, vectorAssembler, chiSqSelector, randomForestClassifier))

    val crossValidator = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(5)

    val cvModel = crossValidator.fit(trainDF)

    val cvPredictionDF = cvModel.transform(testDF)

    // The evaluator reads the "label" column, so this step assumes testDF
    // carries ground-truth answers as well.
    val areaUnderROC = evaluator.evaluate(cvPredictionDF)

    println("Area under ROC with cross validation = " + areaUnderROC)
  }

  def convertArrayToVector: UserDefinedFunction =
    udf((features: mutable.WrappedArray[Double]) => Vectors.dense(features.toArray))

  // Sizes the sparse vector to the largest key present in the row's map; rows
  // whose maximum key differs will produce vectors of different lengths.
  def mapToSparse: UserDefinedFunction =
    udf((map: Map[Int, Double]) => Vectors.sparse(map.keys.max + 1, map.toSeq))
}
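
// A minimal sketch of the CombineMaps aggregator used in loadData; its
// definition is not part of this paste. This version assumes the Spark 2.x
// UserDefinedAggregateFunction API: it folds the per-row maps of a group into
// a single map, combining the values of duplicate keys with the supplied
// function (addition, in the call above).
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}

class CombineMaps[K, V](keyType: DataType, valueType: DataType, combine: (V, V) => V)
  extends UserDefinedAggregateFunction {

  override def inputSchema: StructType = StructType(Seq(StructField("map", dataType)))

  override def bufferSchema: StructType = StructType(Seq(StructField("combined", dataType)))

  override def dataType: DataType = MapType(keyType, valueType)

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer(0) = Map.empty[K, V]

  // Fold one incoming map into the buffer, combining values on key collisions.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val current = buffer.getAs[Map[K, V]](0)
    val incoming = input.getAs[Map[K, V]](0)
    buffer(0) = incoming.foldLeft(current) { case (acc, (k, v)) =>
      acc + (k -> acc.get(k).map(combine(_, v)).getOrElse(v))
    }
  }

  // Merging two partial buffers uses the same key-wise combination.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    update(buffer1, buffer2)

  override def evaluate(buffer: Row): Any = buffer.getAs[Map[K, V]](0)
}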