Advertisement
Guest User

Untitled

a guest
Oct 17th, 2018
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.16 KB | None | 0 0
  1. package com.vng.zing.zudm_zingmp3_suggestion.relistening_prediction
  2.  
  3. import org.apache.spark.sql.SQLContext
  4. import org.apache.spark.sql.Dataset
  5. import org.apache.spark.sql.Row
  6. import com.vng.zing.zudm_zingmp3_suggestion.spark_io.DFGenericInputOutputTransformation
  7. import org.apache.spark.sql.types.{ StructType, LongType, IntegerType, BooleanType, StringType, TimestampType, DoubleType }
  8. import org.apache.spark.sql.functions._
  9. import org.apache.spark.sql.expressions.Window
  10. abstract class FeatureExtractionFromLog extends DFGenericInputOutputTransformation {
  11. def filter: String
  12. def transformation(sc: SQLContext, input: Dataset[Row]): Dataset[Row] = {
  13.  
  14. import sc.implicits._
  15. input.sqlContext.sparkSession.conf.set("spark.sql.autoBroadcastJoinThreshold", 400000000)
  16. // input.sqlContext.sql("set spark.sql.shuffle.partitions=50")
  17.  
  18. val cur_date = new java.sql.Timestamp(System.currentTimeMillis())
  19.  
  20. println("Last timestamp: " + cur_date)
  21.  
  22. val ds = input
  23. .filter(filter)
  24. .withColumn("pos", $"pos".cast(DoubleType))
  25. .withColumn("duration", $"duration".cast(DoubleType))
  26. .withColumn("fullyListen", ($"pos" / $"duration" >= 0.7).cast(IntegerType))
  27. .withColumn("ratio", $"pos" / $"duration")
  28. .withColumn("date", from_unixtime($"timestamp" / 1000, "yyyy-MM-dd HH:mm:ss"))
  29.  
  30. .withColumn("dateTmp", date_add(lit(cur_date), 1))
  31. .withColumn("daysDiff", datediff($"dateTmp", $"date"))
  32. .withColumn("isDownloaded", $"isDownloaded".cast(IntegerType))
  33. .withColumn("age", $"age".cast(IntegerType))
  34. .withColumn("type", $"type".cast(IntegerType))
  35. .withColumn("volume", $"volume".cast(IntegerType))
  36. .filter($"type" === 1 || $"type" === 2)
  37. .na.fill(0, Seq("isDownloaded"))
  38. .na.fill(50, Seq("volume"))
  39. .na.fill(25, Seq("age"))
  40. .na.fill(0, Seq("fullyListen"))
  41. .na.fill(0.5, Seq("ratio"))
  42. .select(
  43. $"user_id".as("uid"),
  44. $"id".as("sid"),
  45. $"isDownloaded",
  46. $"volume",
  47. $"age",
  48. $"fullyListen",
  49. $"ratio",
  50. $"daysDiff")
  51.  
  52. val byLastListen = Window.partitionBy($"uid", $"sid").orderBy($"daysDiff".asc)
  53.  
  54. val group_last_listen_ratio = ds.withColumn("rn", row_number.over(byLastListen))
  55. .where($"rn" === 1)
  56. .drop("rn")
  57. .select($"uid", $"sid", $"ratio".as("lastListenRatio"))
  58.  
  59. val stdDiff = udf((days: Seq[Int]) => {
  60.  
  61. if (days.length == 1) {
  62. 0.0
  63. } else {
  64. val daysSet = days.toSet.toList
  65.  
  66. var dist: List[Int] = List()
  67.  
  68. var i = 0;
  69.  
  70. for (i <- 0 until daysSet.length - 1) {
  71. dist = Math.abs(daysSet(i) - daysSet(i + 1)) :: dist
  72. }
  73.  
  74. val mean = 1.0 / dist.length * dist.reduce(_ + _)
  75. var std = 0.0
  76.  
  77. for (i <- 0 until dist.length) {
  78. std += (mean - dist(i)) * (mean - dist(i))
  79. }
  80.  
  81. std = Math.sqrt(std * 1.0 / (days.length))
  82. std
  83. }
  84. })
  85.  
  86. val group_listen_time_std = ds.groupBy("uid").agg(collect_set("daysDiff").as("listDaysDiff"))
  87. .withColumn("stdDaysDiff", stdDiff($"listDaysDiff"))
  88. .select("uid", "stdDaysDiff")
  89.  
  90. val group_count_listen = ds.groupBy("uid", "sid")
  91. .count()
  92. .withColumnRenamed("count", "count_listen")
  93. .filter($"count_listen" < 200)
  94.  
  95. val group_avg_count_listen = group_count_listen.groupBy("uid")
  96. .count()
  97. .withColumnRenamed("count", "total_listen")
  98.  
  99. val group_avg_count_listen_song = group_count_listen.join(group_avg_count_listen, Seq("uid"), "inner")
  100. .withColumn("avg_listen_ratio", $"count_listen" / $"total_listen")
  101. .select("uid", "sid", "count_listen", "avg_listen_ratio")
  102.  
  103. val group_count_songs = ds.groupBy("uid")
  104. .agg(countDistinct("sid").as("count_song"))
  105. .filter($"count_song" < 300)
  106.  
  107. val group_nearest_listen = ds.groupBy("uid", "sid")
  108. .agg(min("daysDiff").as("nearest_time"))
  109.  
  110. val group_longest_listen = ds.groupBy("uid", "sid")
  111. .agg((max("daysDiff") - min("daysDiff")).as("longest_time"))
  112.  
  113. val group_last_listen = ds.groupBy("uid")
  114. .agg(min("daysDiff").as("last_time"))
  115.  
  116. val group_diff_days = ds.groupBy("uid", "sid")
  117. .agg(countDistinct("daysDiff").as("noDays"))
  118.  
  119. val group_listen_ratio = ds.groupBy("uid", "sid")
  120. .agg(avg("ratio").as("listenRatio"))
  121.  
  122. val group_fully_listen_ratio = ds.groupBy("uid", "sid")
  123. .agg((sum("fullyListen") / count("fullyListen")).as("fullyListenRatio"))
  124.  
  125. val group_age = ds.groupBy("uid")
  126. .agg(max("age").as("age"))
  127. .withColumn("age", when($"age" < 12, 12).when($"age" > 50, 50).otherwise($"age"))
  128.  
  129. val group_is_downloaded = ds.groupBy("uid", "sid")
  130. .agg(max("isDownloaded").as("isDownloaded"))
  131.  
  132. val group_volume = ds.groupBy("uid", "sid")
  133. .agg(mean("volume").as("volume"))
  134.  
  135. val ds_merged_1 = group_avg_count_listen_song
  136. .join(group_nearest_listen, Seq("uid", "sid"), "inner")
  137. .join(group_longest_listen, Seq("uid", "sid"), "inner")
  138. val ds_merged_2 = ds_merged_1
  139. .join(group_diff_days, Seq("uid", "sid"), "inner")
  140. .join(group_listen_ratio, Seq("uid", "sid"), "inner")
  141. val ds_merged_3 = ds_merged_2
  142. .join(group_last_listen_ratio, Seq("uid", "sid"), "inner")
  143. .join(group_fully_listen_ratio, Seq("uid", "sid"), "inner")
  144. val ds_merged_4 = ds_merged_3
  145. .join(group_volume, Seq("uid", "sid"), "inner")
  146. .join(group_count_songs, Seq("uid"), "inner")
  147. val ds_merged_5 = ds_merged_4
  148. .join(group_age, Seq("uid"), "inner")
  149. .join(group_last_listen, Seq("uid"), "inner")
  150. val ds_merged = ds_merged_5
  151. .join(group_listen_time_std, Seq("uid"), "inner")
  152. .join(group_is_downloaded, Seq("uid", "sid"), "inner")
  153.  
  154. ds_merged.coalesce(35)
  155. }
  156.  
  157. }
  158.  
  159. object FeatureExtractionFirstPart extends FeatureExtractionFromLog {
  160. def filter = "user_id < 1010000000"
  161. }
  162.  
  163. object FeatureExtractionSecondPart extends FeatureExtractionFromLog {
  164. def filter = "user_id >= 1010000000 and user_id < 1020000000"
  165. }
  166.  
  167. object FeatureExtractionThirdPart extends FeatureExtractionFromLog {
  168. def filter = "user_id >= 1020000000"
  169. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement