Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package org.apache.spark.countSerDe
- import org.apache.spark.internal.Logging
- import org.apache.spark.sql.Row
- import org.apache.spark.sql.catalyst.InternalRow
- import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
- import org.apache.spark.sql.catalyst.util._
- import org.apache.spark.sql.expressions.MutableAggregationBuffer
- import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
- import org.apache.spark.sql.expressions.UserDefinedImperativeAggregator
- import org.apache.spark.sql.types._
// SQL-facing value carrying an accumulated sum plus two instrumentation
// counters: nSer = how many times the UDT has serialized this value,
// nDeSer = how many times it has deserialized it.
@SQLUserDefinedType(udt = classOf[CountSerDeUDT])
case class CountSerDeSQL(nSer: Int, nDeSer: Int, sum: Double)
/**
 * UDT for [[CountSerDeSQL]] that instruments Catalyst conversions: each
 * `serialize` bumps `nSer` and each `deserialize` bumps `nDeSer`, so a test
 * can observe how often the engine round-trips the value.
 */
class CountSerDeUDT extends UserDefinedType[CountSerDeSQL] {
  def userClass: Class[CountSerDeSQL] = classOf[CountSerDeSQL]

  override def typeName: String = "count-ser-de"

  // No separate nullable variant; this UDT already accepts nulls the same way.
  private[spark] override def asNullable: CountSerDeUDT = this

  // Physical representation: (nSer: int, nDeSer: int, sum: double), all non-null.
  def sqlType: DataType = StructType(Seq(
    StructField("nSer", IntegerType, false),
    StructField("nDeSer", IntegerType, false),
    StructField("sum", DoubleType, false)))

  def serialize(sql: CountSerDeSQL): Any = {
    val out = new GenericInternalRow(3)
    // Serializing counts as one more serialization — intentional instrumentation.
    out.setInt(0, sql.nSer + 1)
    out.setInt(1, sql.nDeSer)
    out.setDouble(2, sql.sum)
    out
  }

  def deserialize(any: Any): CountSerDeSQL = any match {
    case r: InternalRow if r.numFields == 3 =>
      // Deserializing counts as one more deserialization.
      CountSerDeSQL(r.getInt(0), r.getInt(1) + 1, r.getDouble(2))
    case other =>
      throw new Exception(s"failed to deserialize: $other")
  }

  // All instances of this UDT are interchangeable, so equality is by class.
  override def equals(obj: Any): Boolean = obj match {
    case _: CountSerDeUDT => true
    case _ => false
  }

  override def hashCode(): Int = classOf[CountSerDeUDT].getName.hashCode()
}
// Convenient singleton instance of the UDT (all instances compare equal anyway).
case object CountSerDeUDT extends CountSerDeUDT
/**
 * Imperative aggregator over a single double column that sums its input while
 * counting aggregator-buffer (de)serializations: `serialize` bumps `nSer` and
 * `deserialize` bumps `nDeSer` — intentional instrumentation for tests.
 */
case object CountSerDeUDIA extends UserDefinedImperativeAggregator[CountSerDeSQL] {
  import org.apache.spark.unsafe.Platform

  def inputSchema: StructType = StructType(StructField("x", DoubleType) :: Nil)

  def resultType: DataType = CountSerDeUDT

  // The counters depend on how often Spark chooses to (de)serialize the
  // buffer, so results are not deterministic across plans.
  def deterministic: Boolean = false

  def initial: CountSerDeSQL = CountSerDeSQL(0, 0, 0)

  def update(agg: CountSerDeSQL, input: Row): CountSerDeSQL =
    agg.copy(sum = agg.sum + input.getDouble(0))

  def merge(agg1: CountSerDeSQL, agg2: CountSerDeSQL): CountSerDeSQL =
    CountSerDeSQL(
      agg1.nSer + agg2.nSer,
      agg1.nDeSer + agg2.nDeSer,
      agg1.sum + agg2.sum)

  def evaluate(agg: CountSerDeSQL): Any = agg

  // Byte layout (native order via Platform):
  //   [0, 4)  nSer   as int
  //   [4, 8)  nDeSer as int
  //   [8, 16) sum    as double
  def serialize(agg: CountSerDeSQL): Array[Byte] = {
    val buf = new Array[Byte](16)
    val base = Platform.BYTE_ARRAY_OFFSET
    // Serializing counts as one more serialization.
    Platform.putInt(buf, base, agg.nSer + 1)
    Platform.putInt(buf, base + 4, agg.nDeSer)
    Platform.putDouble(buf, base + 8, agg.sum)
    buf
  }

  def deserialize(data: Array[Byte]): CountSerDeSQL = {
    val base = Platform.BYTE_ARRAY_OFFSET
    // Deserializing counts as one more deserialization.
    CountSerDeSQL(
      Platform.getInt(data, base),
      Platform.getInt(data, base + 4) + 1,
      Platform.getDouble(data, base + 8))
  }
}
/**
 * UDAF over a single double column that sums its input. The aggregation
 * buffer holds a [[CountSerDeSQL]] typed with [[CountSerDeUDT]], so the UDT's
 * serialize/deserialize counters record buffer conversions during the query.
 */
case object CountSerDeUDAF extends UserDefinedAggregateFunction {
  // Counter values vary with how often Spark converts the buffer.
  def deterministic: Boolean = false

  def inputSchema: StructType = StructType(StructField("x", DoubleType) :: Nil)

  def bufferSchema: StructType = StructType(StructField("count-ser-de", CountSerDeUDT) :: Nil)

  def dataType: DataType = CountSerDeUDT

  def initialize(buf: MutableAggregationBuffer): Unit = {
    buf(0) = CountSerDeSQL(0, 0, 0)
  }

  def update(buf: MutableAggregationBuffer, input: Row): Unit = {
    val current = buf.getAs[CountSerDeSQL](0)
    buf(0) = current.copy(sum = current.sum + input.getDouble(0))
  }

  def merge(buf1: MutableAggregationBuffer, buf2: Row): Unit = {
    val left = buf1.getAs[CountSerDeSQL](0)
    val right = buf2.getAs[CountSerDeSQL](0)
    buf1(0) = CountSerDeSQL(
      left.nSer + right.nSer,
      left.nDeSer + right.nDeSer,
      left.sum + right.sum)
  }

  def evaluate(buf: Row): Any = buf.getAs[CountSerDeSQL](0)
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement