Guest User

Untitled

a guest
Nov 18th, 2017
61
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 7.12 KB | None | 0 0
  1. package com.meituan.hotel.oe.snappydata.mbl
  2.  
  3. import org.apache.spark.sql.catalyst.dsl.expressions._
  4. import org.apache.spark.sql.catalyst.expressions._
  5. import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
  6. import org.apache.spark.sql.types._
  7.  
  /**
   * Declarative aggregate that buckets crawled price observations into fixed-width time rows
   * and, per row, keeps the observation with the lowest competitor price (ties broken by the
   * lowest Meituan price).
   *
   * Expected children, in order:
   *   1. startTimeInSeconds (constant Int)
   *   2. endTimeInSeconds   (constant Int)
   *   3. stepInSeconds      (constant Int)
   *   4. durationInSeconds  (constant Int)
   *   5. crawlTime
   *   6. compPrice
   *   7. mtPrice
   *   8. weight (optional; defaults to 1)
   *
   * Aggregation buffer:
   *   - `points`:  `numRows * 3 + 2` ints, laid out as (crawlTime, compPrice, mtPrice) per row and
   *                initialised to -1; the trailing two slots are filled at evaluation time with
   *                startTimeInSeconds and stepInSeconds.
   *   - `weights`: `numRows` ints initialised to 0, one weight per row.
   *
   * Result: the `points` array in which every populated triple has been replaced by a one-hot
   * (meet, beat, lose) triple carrying the row weight, followed by startTimeInSeconds and
   * stepInSeconds.
   *
   * NOTE(review): the exact semantics of the project helpers (`ForStep`, `DoSeq`, `Then`, `Else`,
   * `SetArrayItem`, `GetArrayItemWithSize`, `GenerateArray`, `UDFUtils.makeIter`) are not visible
   * here; comments below describe them as they appear to be used - confirm against their sources.
   */
  case class PriceMBL(children: Seq[Expression]) extends DeclarativeAggregate {

    // The first four children are evaluated without an input row, so they must be constants.
    private lazy val startTimeInSeconds: Int = children.head.eval().asInstanceOf[Int]
    private lazy val endTimeInSeconds: Int = children(1).eval().asInstanceOf[Int]
    private lazy val stepInSeconds: Int = children(2).eval().asInstanceOf[Int]
    private lazy val durationInSeconds: Int = children(3).eval().asInstanceOf[Int]

    // Upper bound on the number of rows, to prevent OOM on very wide time ranges.
    private final val MaxRows: Int = 1024 * 3

    private final def numRowsReal: Int = (endTimeInSeconds - startTimeInSeconds) / stepInSeconds

    // Number of rows, capped at MaxRows.
    private final def numRows: Int = math.min(numRowsReal, MaxRows)

    // Each row holds 3 values: (crawlTime, compPrice, mtPrice).
    private final def numPoints: Int = numRows * 3

    // The last two elements hold startTimeInSeconds and stepInSeconds.
    // Lazy so that constructing the expression does not force evaluation of the children,
    // matching the other lazily-derived values in this class.
    private lazy val arraySize: Int = numPoints + 2

    // If duration is smaller than step, some points would not need to be computed. That does not
    // appear to be needed at the moment, so it is not implemented: "less than" and "equal" give
    // the same result.
    private final def pointSpan: Int =
      if (durationInSeconds <= stepInSeconds) 1 else durationInSeconds / stepInSeconds

    // True when the optional 8th "weight" argument was supplied.
    private def hasWeight: Boolean = children.length == 8

    override def inputTypes: Seq[DataType] = Seq(
      IntegerType, // 1: startTimeInSeconds
      IntegerType, // 2: endTimeInSeconds
      IntegerType, // 3: stepInSeconds
      IntegerType, // 4: durationInSeconds
      IntegerType, // 5: crawlTime
      IntegerType, // 6: compPrice
      IntegerType // 7: mtPrice
    ) ++ (if (hasWeight) Seq(IntegerType /* weight */) else Nil) // 8: optional "weight" argument

    override def dataType: DataType = ArrayType(IntegerType)

    override def nullable: Boolean = false

    private lazy val points = AttributeReference("points", ArrayType(IntegerType), nullable = false)()
    private lazy val weights = AttributeReference("weights", ArrayType(IntegerType), nullable = false)()

    override lazy val aggBufferAttributes: Seq[AttributeReference] = points :: weights :: Nil

    override lazy val initialValues: Seq[Expression] = Seq({
      val i = UDFUtils.makeIter("price_mbl_initalValues_points")
      // Array length = number of points + 2; the last two elements hold startTimeInSeconds and
      // stepInSeconds. Every slot starts at -1, which marks it as "not yet populated".
      GenerateArray(Literal(arraySize), i, Literal(-1, IntegerType))
    }, {
      val i = UDFUtils.makeIter("price_mbl_initalValues_weights")
      GenerateArray(Literal(numRows), i, Literal(0, IntegerType))
    })

    override lazy val updateExpressions: Seq[Expression] = {
      val i = UDFUtils.makeIter("price_mbl_updateExpressions")

      val crawlTime = children(4)
      val compPrice = children(5)
      val mtPrice = children(6)

      // The weight defaults to 1 when it is not supplied.
      val weight = if (hasWeight) children(7) else Literal(1)

      Seq(
        DoSeq(
          ForStep(pointSpan, 1, i, {
            // Row that this observation falls into, shifted by the loop offset i.
            val row = (crawlTime - startTimeInSeconds) / stepInSeconds + i
            val pointIndex0 = row * 3
            val pointIndex1 = pointIndex0 + 1
            val pointIndex2 = pointIndex0 + 2

            val prevCrawlTime = GetArrayItemWithSize(arraySize, points, pointIndex0)
            val prevCompPrice = GetArrayItemWithSize(arraySize, points, pointIndex1)
            val prevMtPrice = GetArrayItemWithSize(arraySize, points, pointIndex2)

            // Store the observation when it is inside the (start, end) window and the target row
            // is either still empty or currently holds a worse (higher) price pair.
            If(pointIndex0 < numPoints &&
              crawlTime > startTimeInSeconds && crawlTime < endTimeInSeconds &&
              (prevCrawlTime < 0 || compPrice < prevCompPrice || (compPrice === prevCompPrice && mtPrice < prevMtPrice)),
              Then(
                SetArrayItem(points, pointIndex0, crawlTime),
                SetArrayItem(points, pointIndex1, compPrice),
                SetArrayItem(points, pointIndex2, mtPrice),
                SetArrayItem(weights, row, weight),
                points),
              Else(
                points
              ))
          }),
          points),
        weights
      )
    }

    override lazy val mergeExpressions: Seq[Expression] = {
      val i = UDFUtils.makeIter("price_mbl_mergeExpressions")
      Seq(
        DoSeq(
          ForStep(numPoints, 3, i, {
            val leftCrawlTime = GetArrayItemWithSize(arraySize, points.left, i)
            val leftCompPrice = GetArrayItemWithSize(arraySize, points.left, i + 1)
            val leftMtPrice = GetArrayItemWithSize(arraySize, points.left, i + 2)
            // Weights live in their own buffer, indexed by row (i / 3), not in the points buffer.
            val leftWeight = GetArrayItem(weights.left, i / 3)

            val rightCrawlTime = GetArrayItemWithSize(arraySize, points.right, i)
            val rightCompPrice = GetArrayItemWithSize(arraySize, points.right, i + 1)
            val rightMtPrice = GetArrayItemWithSize(arraySize, points.right, i + 2)
            val rightWeight = GetArrayItem(weights.right, i / 3)

            // Unpopulated slots hold -1 in all three positions, so a plain price comparison would
            // let an empty slot beat a populated one. Keep the left side only when the right side
            // is empty, or when the left side is populated and has the better (lower) price pair.
            val leftWins =
              rightCrawlTime < 0 ||
                (leftCrawlTime >= 0 &&
                  (leftCompPrice < rightCompPrice ||
                    (leftCompPrice === rightCompPrice && leftMtPrice <= rightMtPrice)))

            If(leftWins,
              Then(
                SetArrayItem(points, i, leftCrawlTime),
                SetArrayItem(points, i + 1, leftCompPrice),
                SetArrayItem(points, i + 2, leftMtPrice),
                SetArrayItem(weights, i / 3, leftWeight)
              ),
              Else(
                SetArrayItem(points, i, rightCrawlTime),
                SetArrayItem(points, i + 1, rightCompPrice),
                SetArrayItem(points, i + 2, rightMtPrice),
                SetArrayItem(weights, i / 3, rightWeight)
              )
            )
          }),
          points),
        weights
      )
    }

    override lazy val evaluateExpression: Expression = {
      val i = UDFUtils.makeIter("price_mbl_evaluateExpression")
      DoSeq(
        // Put startTime and step into the last two elements of the array, for use when summing.
        SetArrayItem(points, Literal(numPoints), Literal(startTimeInSeconds)),
        SetArrayItem(points, Literal(numPoints + 1), Literal(stepInSeconds)),

        // The points array holds (crawl time, competitor price, Meituan price) triples.
        // Each populated triple is converted to (meet, beat, lose):
        //   meet = (weight, 0, 0)
        //   beat = (0, weight, 0)
        //   lose = (0, 0, weight)
        ForStep(numPoints, step = 3, i, {
          val crawlTime = GetArrayItem(points, i)
          val compPrice = GetArrayItem(points, i + 1)
          val mtPrice = GetArrayItem(points, i + 2)
          val weight = GetArrayItem(weights, i / 3)

          val meet = compPrice === mtPrice
          val beat = compPrice > mtPrice
          val lose = compPrice < mtPrice

          // Only rows with a positive crawl time are converted; other rows are left unchanged.
          If (crawlTime > 0,
            If (meet,
              DoSeq(
                SetArrayItem(points, i + 0, weight),
                SetArrayItem(points, i + 1, 0),
                SetArrayItem(points, i + 2, 0),
                points),

              If (beat,
                DoSeq(
                  SetArrayItem(points, i + 0, 0),
                  SetArrayItem(points, i + 1, weight),
                  SetArrayItem(points, i + 2, 0),
                  points),

                If (lose,
                  DoSeq(
                    SetArrayItem(points, i + 0, 0),
                    SetArrayItem(points, i + 1, 0),
                    SetArrayItem(points, i + 2, weight),
                    points),
                  points)
              )
            ),
            points)
        }),
        points
      )
    }
  }
Add Comment
Please, Sign In to add comment