Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package com.meituan.hotel.oe.snappydata.mbl
- import org.apache.spark.sql.catalyst.dsl.expressions._
- import org.apache.spark.sql.catalyst.expressions._
- import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
- import org.apache.spark.sql.types._
/**
 * Declarative aggregate over crawled price observations.
 *
 * Arguments (positional, all integers):
 *   1. startTimeInSeconds - start of the window
 *   2. endTimeInSeconds   - end of the window
 *   3. stepInSeconds      - width of one row (time bucket)
 *   4. durationInSeconds  - how long one observation stays applicable
 *   5. crawlTime          - time at which the observation was crawled
 *   6. compPrice          - competitor price
 *   7. mtPrice            - Meituan price
 *   8. weight             - optional; defaults to 1 when omitted
 *
 * Arguments 1-4 are evaluated once without an input row, so they are expected to be constants.
 *
 * Aggregation buffers:
 *   - `points`:  `numRows * 3 + 2` integers. Each row owns three consecutive slots
 *                (crawlTime, compPrice, mtPrice), all initialised to -1 ("unset"). The two
 *                trailing slots are filled with startTimeInSeconds and stepInSeconds on evaluate.
 *   - `weights`: `numRows` integers, initialised to 0, one weight per row.
 *
 * For every row the observation with the lowest competitor price is kept; ties are broken by the
 * lower Meituan price. On evaluate each filled row is rewritten to (meet, beat, lose), where
 * exactly one of the three slots carries the row's weight.
 *
 * NOTE(review): `ForStep`, `DoSeq`, `Then`, `Else`, `GenerateArray`, `SetArrayItem`,
 * `GetArrayItemWithSize` and `UDFUtils.makeIter` are defined elsewhere in the project; the loop
 * and sequencing semantics described in the comments below are assumed from how they are used
 * here - confirm against their definitions.
 */
case class PriceMBL(children: Seq[Expression]) extends DeclarativeAggregate {

  private lazy val startTimeInSeconds: Int = children.head.eval().asInstanceOf[Int]
  private lazy val endTimeInSeconds: Int = children(1).eval().asInstanceOf[Int]
  private lazy val stepInSeconds: Int = children(2).eval().asInstanceOf[Int]
  private lazy val durationInSeconds: Int = children(3).eval().asInstanceOf[Int]

  // Uncapped number of rows covered by [start, end) at the given step.
  private def numRowsReal: Int = (endTimeInSeconds - startTimeInSeconds) / stepInSeconds

  // Number of rows, capped at a maximum to guard against OOM.
  private def numRows: Int = math.min(numRowsReal, 1024 * 3)

  // Three slots per row.
  private def numPoints: Int = numRows * 3

  // The last two elements hold startTimeInSeconds and stepInSeconds.
  // Deliberately a def (not an eager val) so that constructing or copying the expression does
  // not force evaluation of the child expressions.
  private def arraySize: Int = numPoints + 2

  // If duration is smaller than step, some points would not need to be computed. That is not
  // needed at the moment, so it is not implemented: "smaller" and "equal" give the same result.
  private def pointSpan: Int =
    if (durationInSeconds <= stepInSeconds) 1 else durationInSeconds / stepInSeconds

  override def inputTypes: Seq[DataType] = Seq(
    IntegerType, // 1: startTimeInSeconds
    IntegerType, // 2: endTimeInSeconds
    IntegerType, // 3: stepInSeconds
    IntegerType, // 4: durationInSeconds
    IntegerType, // 5: crawlTime
    IntegerType, // 6: compPrice
    IntegerType  // 7: mtPrice
  ) ++ (if (children.length == 8) Seq(IntegerType /* weight */) else Nil) // 8: optional weight

  override def dataType: DataType = ArrayType(IntegerType)

  override def nullable: Boolean = false

  private lazy val points = AttributeReference("points", ArrayType(IntegerType), nullable = false)()
  private lazy val weights = AttributeReference("weights", ArrayType(IntegerType), nullable = false)()

  override lazy val aggBufferAttributes: Seq[AttributeReference] = points :: weights :: Nil

  override lazy val initialValues: Seq[Expression] = Seq({
    val i = UDFUtils.makeIter("price_mbl_initalValues_points")
    // Array length = number of points + 2; the last two elements hold
    // startTimeInSeconds and stepInSeconds. Every slot starts as -1 (unset).
    GenerateArray(Literal(arraySize), i, Literal(-1, IntegerType))
  }, {
    val i = UDFUtils.makeIter("price_mbl_initalValues_weights")
    // One weight per row, starting at 0.
    GenerateArray(Literal(numRows), i, Literal(0, IntegerType))
  })

  override lazy val updateExpressions: Seq[Expression] = {
    val i = UDFUtils.makeIter("price_mbl_updateExpressions")
    val crawlTime = children(4)
    val compPrice = children(5)
    val mtPrice = children(6)
    // Weight defaults to 1 when the eighth argument is not supplied.
    val weight = if (children.length == 8) children(7) else Literal(1)
    Seq(
      DoSeq(
        // Presumably iterates `i` over `pointSpan` consecutive rows, so that one observation
        // covers its own row plus the following rows within its duration.
        ForStep(pointSpan, 1, i, {
          val row = (crawlTime - startTimeInSeconds) / stepInSeconds + i
          val pointIndex0 = row * 3
          val pointIndex1 = pointIndex0 + 1
          val pointIndex2 = pointIndex0 + 2
          val prevCrawlTime = GetArrayItemWithSize(arraySize, points, pointIndex0)
          val prevCompPrice = GetArrayItemWithSize(arraySize, points, pointIndex1)
          val prevMtPrice = GetArrayItemWithSize(arraySize, points, pointIndex2)
          // Write the observation when the row is inside the buffer, the crawl time is strictly
          // inside (start, end), and the slot is either unset or holds a worse observation
          // (higher competitor price, or equal competitor price with a higher Meituan price).
          If(pointIndex0 < numPoints &&
            crawlTime > startTimeInSeconds && crawlTime < endTimeInSeconds &&
            (prevCrawlTime < 0 || compPrice < prevCompPrice ||
              (compPrice === prevCompPrice && mtPrice < prevMtPrice)),
            Then(
              SetArrayItem(points, pointIndex0, crawlTime),
              SetArrayItem(points, pointIndex1, compPrice),
              SetArrayItem(points, pointIndex2, mtPrice),
              SetArrayItem(weights, row, weight),
              points),
            Else(
              points
            ))
        }),
        points),
      // `weights` is modified by the SetArrayItem calls in the loop above.
      weights
    )
  }

  override lazy val mergeExpressions: Seq[Expression] = {
    val i = UDFUtils.makeIter("price_mbl_mergeExpressions")
    Seq(
      DoSeq(
        // Walk the rows; `i` is the index of the first slot of each row, so `i / 3` is the row.
        ForStep(numPoints, 3, i, {
          val leftCrawlTime = GetArrayItemWithSize(arraySize, points.left, i)
          val leftCompPrice = GetArrayItemWithSize(arraySize, points.left, i + 1)
          val leftMtPrice = GetArrayItemWithSize(arraySize, points.left, i + 2)
          // Weights live in their own buffer of numRows elements, not in `points`.
          val leftWeight = GetArrayItemWithSize(numRows, weights.left, i / 3)
          val rightCrawlTime = GetArrayItemWithSize(arraySize, points.right, i)
          val rightCompPrice = GetArrayItemWithSize(arraySize, points.right, i + 1)
          val rightMtPrice = GetArrayItemWithSize(arraySize, points.right, i + 2)
          val rightWeight = GetArrayItemWithSize(numRows, weights.right, i / 3)

          // Same ordering as in updateExpressions: lower competitor price wins, ties go to the
          // lower Meituan price (left side kept on a full tie).
          val leftNotWorse = leftCompPrice < rightCompPrice ||
            (leftCompPrice === rightCompPrice && leftMtPrice <= rightMtPrice)
          // An unset slot (crawl time -1) must never beat a filled one: its -1 prices would
          // otherwise compare as "cheaper" and discard the real observation.
          val keepLeft = rightCrawlTime < 0 || (leftCrawlTime >= 0 && leftNotWorse)

          If(keepLeft,
            Then(
              SetArrayItem(points, i, leftCrawlTime),
              SetArrayItem(points, i + 1, leftCompPrice),
              SetArrayItem(points, i + 2, leftMtPrice),
              SetArrayItem(weights, i / 3, leftWeight)
            ),
            Else(
              SetArrayItem(points, i, rightCrawlTime),
              SetArrayItem(points, i + 1, rightCompPrice),
              SetArrayItem(points, i + 2, rightMtPrice),
              SetArrayItem(weights, i / 3, rightWeight)
            )
          )
        }),
        points),
      // `weights` is modified by the SetArrayItem calls in the loop above.
      weights
    )
  }

  override lazy val evaluateExpression: Expression = {
    val i = UDFUtils.makeIter("price_mbl_evaluateExpression")
    DoSeq(
      // Store startTime and step in the last two elements of the array, for use when summing.
      SetArrayItem(points, Literal(numPoints), Literal(startTimeInSeconds)),
      SetArrayItem(points, Literal(numPoints + 1), Literal(stepInSeconds)),
      // The points array holds (crawl time, competitor price, Meituan price) per row.
      // Convert each filled row to (meet, beat, lose):
      //   meet = (weight, 0, 0)   competitor price == Meituan price
      //   beat = (0, weight, 0)   competitor price >  Meituan price
      //   lose = (0, 0, weight)   competitor price <  Meituan price
      // Rows whose crawl time is not positive are left untouched.
      ForStep(numPoints, step = 3, i, {
        val crawlTime = GetArrayItem(points, i)
        val compPrice = GetArrayItem(points, i + 1)
        val mtPrice = GetArrayItem(points, i + 2)
        val weight = GetArrayItem(weights, i / 3)
        val meet = compPrice === mtPrice
        val beat = compPrice > mtPrice
        val lose = compPrice < mtPrice
        If (crawlTime > 0,
          If (meet,
            DoSeq(
              SetArrayItem(points, i + 0, weight),
              SetArrayItem(points, i + 1, 0),
              SetArrayItem(points, i + 2, 0),
              points),
            If (beat,
              DoSeq(
                SetArrayItem(points, i + 0, 0),
                SetArrayItem(points, i + 1, weight),
                SetArrayItem(points, i + 2, 0),
                points),
              If (lose,
                DoSeq(
                  SetArrayItem(points, i + 0, 0),
                  SetArrayItem(points, i + 1, 0),
                  SetArrayItem(points, i + 2, weight),
                  points),
                points)
            )
          ),
          points)
      }),
      points
    )
  }
}
Add Comment
Please, Sign In to add comment