import scala.reflect.runtime.universe._

import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.functions.window
import org.apache.spark.sql.types.{LongType, StructType, TimestampType}

// Look up a Scala `object` (module) by its fully qualified name and return a
// mirror exposing its singleton instance.
def getObjectInstance(mirror: Mirror, clsName: String): ModuleMirror = {
  val module = mirror.staticModule(clsName)
  mirror.reflectModule(module)
}

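// `getType` is referenced below but is not included in this paste. The sketch
// here is one plausible implementation, assuming overloads are matched on the
// same string form that `typeSignature.resultType.toString` produces for each
// declared parameter; it is an illustration, not the original helper.
def getType(value: Any): String = value match {
  case _: Column  => "org.apache.spark.sql.Column" // how Scala reflection prints the Column type
  case _: Boolean => "Boolean"
  case _: Double  => "Double"
  case other      => other.getClass.getName        // fallback; may not match Scala's type printing
}
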
// Reflectively resolve the overload of `funcName` on the object `clsName`
// (e.g. the Spark SQL functions object) whose declared parameter types match
// the runtime types of `params`, invoke it, and return the resulting Column.
def reflectSQLFunction(clsName: String)(funcName: String)(params: Any*): Column = {
  val mirror = runtimeMirror(getClass.getClassLoader)
  val instanceMirror: InstanceMirror = mirror.reflect(getObjectInstance(mirror, clsName).instance)
  val paramsType = params.map(getType)
  // All overloads sharing the requested name.
  val methods = instanceMirror.symbol.typeSignature.member(TermName(funcName)).asTerm.alternatives
  // Keep only the overloads whose parameter type strings match the arguments.
  val targetMethod = methods.filter { method =>
    val methodArgs = method.asInstanceOf[MethodSymbol].paramLists.head.map { symbol =>
      symbol.typeSignature.resultType.toString
    }
    methodArgs == paramsType
  }
  if (targetMethod.size == 1) {
    instanceMirror.reflectMethod(targetMethod.head.asMethod)(params: _*).asInstanceOf[Column]
  } else {
    throw new IllegalStateException(s"Expected exactly one function matching parameter types " +
      s"${paramsType}, but found ${targetMethod.size}.")
  }
}

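// Example usage (a sketch): resolving org.apache.spark.sql.functions.approx_count_distinct
// by name. "org.apache.spark.sql.functions" stands in here for
// AppConstant.SqlObjectFunction, whose actual value is not shown in this paste.
//
// val distinctUsers: Column =
//   reflectSQLFunction("org.apache.spark.sql.functions")("approx_count_distinct")(new Column("userId"), 0.05)
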
def constructContinueAggregation(dataframe: DataFrame,
                                 continuousAggregationSpec: ContinuousAggregationSpec): DataFrame = {

  val eventTimeField = continuousAggregationSpec.timeField
  val windowLength = continuousAggregationSpec.window.length
  val windowSlide = continuousAggregationSpec.window.slide
  val waterMarkTime = continuousAggregationSpec.watermark
  val aggregationFunctions = continuousAggregationSpec.aggregationFunctions
  val groupByFields = continuousAggregationSpec.groupByFields
  val filterExpr = continuousAggregationSpec.filterExpr
  val addedFields = continuousAggregationSpec.addedFields

  // Project/filter via generated SQL over a temp view, then cast the event-time
  // column to timestamp so it can drive windowing and watermarks.
  val selectAndFilterSQL = constructSelectAndFilterSQL(filterExpr, addedFields)
  dataframe.createOrReplaceTempView(AppConstant.DefaultTable)

  val tmpSelectAndFilterDF = dataframe.sparkSession.sql(selectAndFilterSQL)
  val selectAndFilterDF = tmpSelectAndFilterDF.withColumn(eventTimeField,
    tmpSelectAndFilterDF(eventTimeField).cast(TimestampType))

  val aggregatedDF = (groupByFields.isDefined, aggregationFunctions.isDefined) match {

    case (_, false) =>
      throw new IllegalArgumentException("Must have at least one aggregation function.")

    case (_, true) =>
      // Partially applied; the trailing varargs list eta-expands to Seq[Any] => Column.
      val sparkAggFunctions: (String) => (Seq[Any]) => Column = reflectSQLFunction(AppConstant.SqlObjectFunction)(_)
      val aggregationColumns: Seq[Column] = aggregationFunctions.head.map { aggregationFunction =>
        val funcName = aggregationFunction.func
        val alias = aggregationFunction.alias
        val fieldColumn = Column(parseExpressionFunc(aggregationFunction.field))
        val resColumn = if (funcName.isDefined) {
          // otherParam, when present, is a boolean or double literal
          val otherParam: Seq[AnyVal] = aggregationFunction.otherParam match {
            case Some(param) => param.trim match {
              case "true"  => Seq(true)
              case "false" => Seq(false)
              case x       => Seq(x.toDouble)
            }
            case None => Seq()
          }
          sparkAggFunctions(funcName.head)(Seq(fieldColumn) ++ otherParam)
        } else {
          fieldColumn
        }

        if (alias.isDefined) resColumn.as(alias.head) else resColumn
      }
      // agg() takes the first column plus varargs for the rest.
      val headColumn = aggregationColumns.head
      val tailColumns = aggregationColumns.tail.toArray

      val groupByDF = if (groupByFields.isDefined) {
        val groupByColumns = groupByFields.head.map { field =>
          log.info(s"Group By field is ${field}")
          new Column(field)
        }
        selectAndFilterDF
          .withWatermark(eventTimeField, waterMarkTime)
          .groupBy((List(window(new Column(eventTimeField), windowLength, windowSlide)) ++ groupByColumns): _*)
      } else {
        selectAndFilterDF
          .withWatermark(eventTimeField, waterMarkTime)
          .groupBy(window(new Column(eventTimeField), windowLength, windowSlide))
      }
      groupByDF.agg(headColumn, tailColumns: _*)
  }
  aggregatedDF
}

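// The ContinuousAggregationSpec type (and its aggregation-function entries) is
// not included in this paste. A hypothetical shape, inferred only from the
// field accesses above, might look like the following; the real definition may
// differ, in particular for filterExpr and addedFields:
//
// case class WindowSpec(length: String, slide: String)
// case class AggregationFunction(field: String, func: Option[String],
//                                alias: Option[String], otherParam: Option[String])
// case class ContinuousAggregationSpec(timeField: String,
//                                      window: WindowSpec,
//                                      watermark: String,
//                                      aggregationFunctions: Option[Seq[AggregationFunction]],
//                                      groupByFields: Option[Seq[String]],
//                                      filterExpr: Option[String],
//                                      addedFields: Option[Seq[String]])
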
// Build an empty DataFrame with the given schema, run the aggregation over it,
// and return the resolved output attributes, so the final schema can be
// validated before any streaming query starts.
def checkAndGetFinalSchema(structType: StructType, continueAggregationSpec: ContinuousAggregationSpec): Seq[Attribute] = {

  val attributes = structType.fields.map { field =>
    AttributeReference(field.name, field.dataType)()
  }.toSeq

  val logicalPlan = LocalRelation(attributes)
  val encoder: ExpressionEncoder[Row] = RowEncoder(structType)
  val inputDataFrame: Dataset[Row] = Dataset(sparkSession, logicalPlan)(encoder).toDF()
  // Flatten the window struct into explicit start/end epoch columns.
  val aggregation: DataFrame = constructContinueAggregation(inputDataFrame, continueAggregationSpec)
    .withColumn("windowStart", new Column("window.start").cast(LongType))
    .withColumn("windowEnd", new Column("window.end").cast(LongType))
    .drop("window")

  val resolvedPlan = aggregation.logicalPlan
  log.info("final schema is: ")
  log.info(resolvedPlan.output.toString())
  resolvedPlan.output
}
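
// Example usage (a sketch, assuming a `spec` built along the lines of the
// hypothetical ContinuousAggregationSpec shape above, plus StructField/StringType
// imports from org.apache.spark.sql.types):
//
// val inputSchema = StructType(Seq(
//   StructField("eventTime", StringType),
//   StructField("userId", StringType)
// ))
// val outputAttrs = checkAndGetFinalSchema(inputSchema, spec)
// outputAttrs.foreach(attr => println(s"${attr.name}: ${attr.dataType}"))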