Advertisement
Guest User

Untitled

a guest
Jul 18th, 2019
63
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.89 KB | None | 0 0
  1. package org.apache.spark.countSerDe
  2.  
  3. import org.apache.spark.internal.Logging
  4. import org.apache.spark.sql.Row
  5. import org.apache.spark.sql.catalyst.InternalRow
  6. import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
  7. import org.apache.spark.sql.catalyst.util._
  8. import org.apache.spark.sql.expressions.MutableAggregationBuffer
  9. import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
  10. import org.apache.spark.sql.expressions.UserDefinedImperativeAggregator
  11. import org.apache.spark.sql.types._
  12.  
  13.  
// SQL-level value carried by the UDT: nSer/nDeSer count how many times this
// value has crossed the serialize/deserialize boundary, and sum holds the
// aggregated payload. The counters are the whole point of this test type.
@SQLUserDefinedType(udt = classOf[CountSerDeUDT])
case class CountSerDeSQL(nSer: Int, nDeSer: Int, sum: Double)
  16.  
  17. class CountSerDeUDT extends UserDefinedType[CountSerDeSQL] {
  18. def userClass: Class[CountSerDeSQL] = classOf[CountSerDeSQL]
  19.  
  20. override def typeName: String = "count-ser-de"
  21.  
  22. private[spark] override def asNullable: CountSerDeUDT = this
  23.  
  24. def sqlType: DataType = StructType(
  25. StructField("nSer", IntegerType, false) ::
  26. StructField("nDeSer", IntegerType, false) ::
  27. StructField("sum", DoubleType, false) ::
  28. Nil)
  29.  
  30. def serialize(sql: CountSerDeSQL): Any = {
  31. val row = new GenericInternalRow(3)
  32. row.setInt(0, 1 + sql.nSer)
  33. row.setInt(1, sql.nDeSer)
  34. row.setDouble(2, sql.sum)
  35. row
  36. }
  37.  
  38. def deserialize(any: Any): CountSerDeSQL = any match {
  39. case row: InternalRow if (row.numFields == 3) =>
  40. CountSerDeSQL(row.getInt(0), 1 + row.getInt(1), row.getDouble(2))
  41. case u => throw new Exception(s"failed to deserialize: $u")
  42. }
  43.  
  44. override def equals(obj: Any): Boolean = {
  45. obj match {
  46. case _: CountSerDeUDT => true
  47. case _ => false
  48. }
  49. }
  50.  
  51. override def hashCode(): Int = classOf[CountSerDeUDT].getName.hashCode()
  52. }
  53.  
// Shared singleton instance of the UDT, usable directly wherever a DataType is expected.
case object CountSerDeUDT extends CountSerDeUDT
  55.  
  56. case object CountSerDeUDIA extends UserDefinedImperativeAggregator[CountSerDeSQL] {
  57. import org.apache.spark.unsafe.Platform
  58.  
  59. def inputSchema: StructType = StructType(StructField("x", DoubleType) :: Nil)
  60. def resultType: DataType = CountSerDeUDT
  61. def deterministic: Boolean = false
  62. def initial: CountSerDeSQL = CountSerDeSQL(0, 0, 0)
  63. def update(agg: CountSerDeSQL, input: Row): CountSerDeSQL =
  64. agg.copy(sum = agg.sum + input.getDouble(0))
  65. def merge(agg1: CountSerDeSQL, agg2: CountSerDeSQL): CountSerDeSQL =
  66. CountSerDeSQL(agg1.nSer + agg2.nSer, agg1.nDeSer + agg2.nDeSer, agg1.sum + agg2.sum)
  67. def evaluate(agg: CountSerDeSQL): Any = agg
  68.  
  69. def serialize(agg: CountSerDeSQL): Array[Byte] = {
  70. val CountSerDeSQL(ns, nd, s) = agg
  71. val byteArray = new Array[Byte](4 + 4 + 8)
  72. Platform.putInt(byteArray, Platform.BYTE_ARRAY_OFFSET, ns + 1)
  73. Platform.putInt(byteArray, Platform.BYTE_ARRAY_OFFSET + 4, nd)
  74. Platform.putDouble(byteArray, Platform.BYTE_ARRAY_OFFSET + 8, s)
  75. byteArray
  76. }
  77.  
  78. def deserialize(data: Array[Byte]): CountSerDeSQL = {
  79. val ns = Platform.getInt(data, Platform.BYTE_ARRAY_OFFSET)
  80. val nd = Platform.getInt(data, Platform.BYTE_ARRAY_OFFSET + 4)
  81. val s = Platform.getDouble(data, Platform.BYTE_ARRAY_OFFSET + 8)
  82. CountSerDeSQL(ns, nd + 1, s)
  83. }
  84. }
  85.  
  86. case object CountSerDeUDAF extends UserDefinedAggregateFunction {
  87. def deterministic: Boolean = false
  88.  
  89. def inputSchema: StructType = StructType(StructField("x", DoubleType) :: Nil)
  90.  
  91. def bufferSchema: StructType = StructType(StructField("count-ser-de", CountSerDeUDT) :: Nil)
  92.  
  93. def dataType: DataType = CountSerDeUDT
  94.  
  95. def initialize(buf: MutableAggregationBuffer): Unit = {
  96. buf(0) = CountSerDeSQL(0, 0, 0)
  97. }
  98.  
  99. def update(buf: MutableAggregationBuffer, input: Row): Unit = {
  100. val sql = buf.getAs[CountSerDeSQL](0)
  101. buf(0) = sql.copy(sum = sql.sum + input.getDouble(0))
  102. }
  103.  
  104. def merge(buf1: MutableAggregationBuffer, buf2: Row): Unit = {
  105. val sql1 = buf1.getAs[CountSerDeSQL](0)
  106. val sql2 = buf2.getAs[CountSerDeSQL](0)
  107. buf1(0) = CountSerDeSQL(sql1.nSer + sql2.nSer, sql1.nDeSer + sql2.nDeSer, sql1.sum + sql2.sum)
  108. }
  109.  
  110. def evaluate(buf: Row): Any = buf.getAs[CountSerDeSQL](0)
  111. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement