Advertisement
Guest User

Untitled

a guest
May 16th, 2019
99
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Scala 6.06 KB | None | 0 0
  1. package sampling
  2.  
  3. import org.apache.spark.sql.Row
  4. import org.apache.spark.rdd.RDD
  5. import org.apache.spark.sql.DataFrame
  6. import org.apache.spark.sql.functions._
  7.  
  8. import scala.util.Random
  9.  
  10. object Sampler extends Serializable {
  // Assumed storage footprint of one sampled row, in bytes. This is a placeholder guess,
  // not a measured value.
  val sizeOfRowBytes: Int = 10

  // Multiplier applied to the standard error in guess_K (read there as `Za`). This is a
  // fixed placeholder; it is not derived from any requested confidence level.
  val percentile: Double = 1.96
  15.  
  // Returns the query column sets (QCS) that samples are built for, as a map from the
  // space-separated column list to the query numbers attached to that column set.
  // Only column sets serving a query in `enabledQueries` are returned; the rest of the
  // catalogue stays here so a column set can be switched on by adding its query number.
  val get_all_QCS: () => Map[String, Seq[Int]] = () => {
    val catalogue: Seq[(String, Seq[Int])] = Seq(
      "l_shipdate l_returnflag l_linestatus" -> Seq(1),
      "l_orderkey l_shipdate" -> Seq(3),
      "l_orderkey l_suppkey" -> Seq(5),
      "l_shipdate l_discount l_quantity" -> Seq(6),
      // "l_extendedprice l_shipdate l_suppkey l_orderkey" -> Seq(7),
      "l_suppkey l_partkey l_orderkey" -> Seq(9),
      "l_returnflag l_orderkey" -> Seq(10),
      "l_orderkey l_shipmode l_commitdate l_shipdate l_receiptdate" -> Seq(12),
      "l_quantity l_partkey" -> Seq(17),
      "l_orderkey l_quantity" -> Seq(18),
      "l_partkey l_quantity l_shipmode l_shipinstruct" -> Seq(19),
      "l_partkey l_suppkey l_shipdate l_quantity" -> Seq(20)
    )
    // Queries whose column sets are currently sampled.
    val enabledQueries = Set(1, 3, 5, 6)
    // Catalogue order is preserved, so callers iterate the enabled sets in declaration order.
    catalogue.filter { case (_, queries) => queries.exists(enabledQueries) }.toMap
  }
  32.  
  33.  
  // Builds the stratified samples of `lineitem` that fit into `storageBudgetBytes`.
  //
  // For every column set returned by get_all_QCS, find_K is asked for a per-stratum cap K
  // that meets the relative-error bound `e` (`ci` is forwarded unchanged). Column sets for
  // which no K is found are skipped. The remaining candidates are filtered against the
  // budget by which_samples_to_generate and materialised by generate_samples, whose result
  // (the sample RDDs plus the query-number -> sample-index map) is returned as is.
  def sample(lineitem: DataFrame, storageBudgetBytes: Long, e: Double, ci: Double): (List[RDD[_]], _) = {
    val candidates: Seq[(Long, Long, Seq[Int], String)] =
      get_all_QCS().toSeq.flatMap { case (columns, queries) =>
        // find_K signals "no suitable K" with null; Option(...) turns that into None.
        Option(find_K(lineitem, columns.split(" "), e, ci)).map { case (k, sampledRows) =>
          // which_samples_to_generate reads the tuple as (cost, K, queries, qcs): the cost is
          // the estimated size of the sample in bytes and K is the per-stratum row cap.
          (sampledRows * sizeOfRowBytes, k.toLong, queries, columns)
        }
      }

    generate_samples(lineitem, which_samples_to_generate(candidates, storageBudgetBytes))
  }
  52.  
  // Materialises one stratified sample of `lineitem` per entry of `Ks`.
  //
  // Each entry of `Ks` is (K, queries, qcs): the per-stratum row cap, the query numbers the
  // sample serves, and the space-separated columns that define a stratum.
  //
  // Returns the sample RDDs in the order of `Ks`, together with a Map[Int, Int] from query
  // number to the index of its sample in that list. Every sampled row carries one extra
  // trailing column: the fraction of its stratum that was kept, rendered as a string.
  def generate_samples(lineitem: DataFrame, Ks: Seq[(Long, Seq[Int], String)]): (List[RDD[_]], _) = {
    val rows = lineitem.rdd

    // Keeps at most `cap` randomly chosen rows of one stratum and tags each kept row with
    // the sampling fraction that was applied to that stratum.
    val sampleStratum = (stratum: Seq[Row], cap: Long) => {
      val size = stratum.length.toLong
      val (kept, fraction) =
        if (size <= cap) {
          (stratum, 1.0)
        } else {
          // Here cap < size <= Int.MaxValue, so cap.toInt cannot overflow.
          (Random.shuffle(stratum).take(cap.toInt), cap.toDouble / size)
        }
      kept.map(row => Row.fromSeq(row.toSeq :+ fraction.toString))
    }

    val samples = Ks.map { case (cap, _, qcs) =>
      val columns = qcs.split(" ")
      rows
        // The stratum key is built from the VALUES of all QCS columns of the row.
        .groupBy((row: Row) => columns.map(c => String.valueOf(row.getAs[Any](c))).mkString(" "))
        .flatMap { case (_, stratum) => sampleStratum(stratum.toSeq, cap) }
    }

    // A query listed under several samples ends up mapped to the last one, as before.
    val queryMap: Map[Int, Int] = Ks.zipWithIndex.flatMap { case ((_, queries, _), index) =>
      queries.map(queryNum => queryNum -> index)
    }.toMap

    (samples.toList, queryMap)
  }
  97.  
  // Greedily picks the samples that fit into `budget`.
  //
  // Each entry of `K_seq` is (cost, K, queries, qcs). Entries are visited in ascending order
  // of the number of queries they serve (ties keep their input order). An entry is selected
  // when its cost still fits into the remaining budget; entries that do not fit are skipped
  // and later, cheaper ones may still be selected.
  // Returns the selected entries as (K, queries, qcs), in visiting order.
  //
  // NOTE(review): the ascending order is the one the original sortBy call spelled out (its
  // result used to be discarded). Confirm that samples serving fewer queries really should
  // get the budget first.
  def which_samples_to_generate(K_seq: Seq[(Long, Long, Seq[Int], String)], budget: Long): Seq[(Long, Seq[Int], String)] = {
    val prioritised = K_seq.sortBy { case (_, _, queries, _) => queries.length }

    val (_, chosen) =
      prioritised.foldLeft((budget, Vector.empty[(Long, Seq[Int], String)])) {
        case ((remaining, selected), (cost, k, queries, qcs)) =>
          // `cost <= remaining` rather than `remaining - cost >= 0`: it cannot overflow.
          if (cost <= remaining) (remaining - cost, selected :+ ((k, queries, qcs)))
          else (remaining, selected)
      }

    chosen
  }
  114.  
  // Groups `lineitem` by `col1` plus the columns in `qcs` and computes, per group, the row
  // count ("count"), the standard deviation of `aggAttr` ("stddev") and its sum ("sum").
  // Returns the per-group rows as an RDD together with the number of groups.
  val get_QCS_data = (lineitem: DataFrame, col1: String, aggAttr: String, qcs: Array[String]) => {
    val grouped = lineitem.groupBy(col1, qcs: _*)
    val perStratum = grouped
      .agg(
        count("*").as("count"),
        stddev(aggAttr).as("stddev"),
        sum(aggAttr).as("sum"))
      .rdd
    val strataCount = perStratum.count()
    (perStratum, strataCount)
  }
  120.  
  // Checks whether capping every stratum at `K` rows keeps the relative error within `e`.
  //
  // `rdd` holds one row per stratum with the columns "count", "stddev" and "sum" (the shape
  // produced by get_QCS_data). The rows are folded on the driver into:
  //   - a variance term: the sum over strata of Nh * Nh * Sh * Sh / min(Nh, K),
  //   - the grand total of the "sum" column,
  //   - the number of rows the sample would hold: the sum over strata of min(Nh, K).
  // The relative error is then percentile * sqrt(variance term) / |total|.
  //
  // Returns (true, sampleRows) when the error is at most `e`, and (false, 0) otherwise,
  // including when the total is zero. `ci` and `count` are accepted for interface
  // compatibility but are not used: the multiplier is the fixed `percentile`.
  val guess_K = (rdd: RDD[Row], K: Double, e: Double, ci: Double, count: Long) => {
    val (variance, total, sampleRows) =
      rdd.toLocalIterator.foldLeft((0.0, 0.0, 0L)) { case ((varAcc, totalAcc, rowsAcc), row) =>
        val Nh = row.getAs[Long]("count")
        // A single-row stratum has no spread; its "stddev" value is not read at all.
        val Sh = if (Nh == 1) 0.0 else row.getAs[Double]("stddev")
        // Keep the fractional part of the per-stratum sum, whatever numeric type it has;
        // a null sum contributes nothing.
        val stratumTotal = row.getAs[Any]("sum") match {
          case n: Number => n.doubleValue()
          case _ => 0.0
        }
        if (Nh < K) {
          // Stratum smaller than the cap: every row is taken.
          (varAcc + ((1.toDouble / Nh) * Nh * Nh * Sh * Sh), totalAcc + stratumTotal, rowsAcc + Nh)
        } else {
          (varAcc + ((1.toDouble / K) * Nh * Nh * Sh * Sh), totalAcc + stratumTotal, rowsAcc + K.toLong)
        }
      }

    val Za = percentile // Replace this later
    if (total == 0.0) {
      // The relative error is undefined for a zero total.
      (false, 0.toLong)
    } else {
      // abs(): a negative total must not turn into a negative, always-accepted error.
      val relErr = (Za * math.sqrt(variance)) / math.abs(total)
      println(s"Relative error: $relErr")
      if (relErr <= e) (true, sampleRows) else (false, 0.toLong)
    }
  }
  166.  
  167.  
  168.  
  169.  
  // Column that find_K hands to get_QCS_data as the attribute to aggregate.
  val const_column: String = "l_extendedprice"

  // Factor by which find_K grows its candidate K between attempts.
  // Will probably need to be bigger sooner or later.
  val K_INCREMENTS: Double = 1.5

  // Value find_K starts its candidate K from.
  val INITIAL_K: Double = 100.0
  175.  
  // Searches for a per-stratum cap K for which guess_K reports an acceptable error.
  //
  // `qcs` lists the stratification columns and must be non-empty (`qcs.head` is used as the
  // first grouping column). Candidates start at INITIAL_K, which is itself tested, and are
  // multiplied by K_INCREMENTS after every failed attempt. The search stops at the first
  // accepted candidate, or once a rejected candidate is no longer below `population`, the
  // group count returned by get_QCS_data.
  //
  // Returns (K, sampleRows) for the accepted candidate, or null when none was accepted.
  // Callers compare the result against null, so that contract is kept.
  val find_K = (lineitem: DataFrame, qcs: Array[String], e: Double, ci: Double) => {
    val (groups, population) = get_QCS_data(lineitem, qcs.head, const_column, qcs.tail)
    // Every guess_K call walks the whole RDD; cache it so the group-by aggregation is not
    // recomputed for each candidate.
    groups.cache()
    try {
      var K = INITIAL_K
      var attempt = guess_K(groups, K, e, ci, population)
      while (!attempt._1 && K < population) {
        K *= K_INCREMENTS
        attempt = guess_K(groups, K, e, ci, population)
      }
      if (attempt._1) (K, attempt._2) else null
    } finally {
      groups.unpersist()
    }
  }
  200. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement