Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package sampling
- import org.apache.spark.sql.Row
- import org.apache.spark.rdd.RDD
- import org.apache.spark.sql.DataFrame
- import org.apache.spark.sql.functions._
- import scala.util.Random
/** Builds stratified samples of the `lineitem` table, one per query column set (QCS), and
  * picks which of them to materialise within a storage budget.
  */
object Sampler extends Serializable {
  /** Estimated size in bytes of one sampled row, used to turn a sample's row count into a
    * storage cost. Placeholder value, not a measured figure.
    */
  val sizeOfRowBytes = 10

  /** z-score used for the error bound. 1.96 is the two-sided 95% normal quantile; it is applied
    * regardless of the `ci` argument passed to [[sample]].
    */
  val percentile = 1.96

  /** Returns the enabled query column sets, each mapped to the query numbers it serves.
    *
    * Keys are space-separated column names. Only QCS1, QCS2, QCS3 and QCS5 are enabled; the
    * other definitions are kept so they can be switched on by adding them to the map.
    */
  val get_all_QCS: () => Map[String, Seq[Int]] = () => {
    val QCS1 = "l_shipdate l_returnflag l_linestatus" -> Seq(1)
    val QCS2 = "l_orderkey l_shipdate" -> Seq(3)
    val QCS3 = "l_orderkey l_suppkey" -> Seq(5)
    val QCS5 = "l_shipdate l_discount l_quantity" -> Seq(6)
    //val QCS6 = "l_extendedprice l_shipdate l_suppkey l_orderkey" -> Seq(7)
    val QCS7 = "l_suppkey l_partkey l_orderkey" -> Seq(9)
    val QCS8 = "l_returnflag l_orderkey" -> Seq(10)
    val QCS9 = "l_orderkey l_shipmode l_commitdate l_shipdate l_receiptdate" -> Seq(12)
    val QCS10 = "l_quantity l_partkey" -> Seq(17)
    val QCS11 = "l_orderkey l_quantity" -> Seq(18)
    val QCS12 = "l_partkey l_quantity l_shipmode l_shipinstruct" -> Seq(19)
    val QCS13 = "l_partkey l_suppkey l_shipdate l_quantity" -> Seq(20)
    Map(QCS1, QCS2, QCS3, QCS5)
  }

  /** Builds the stratified samples for every enabled QCS that fits in the storage budget.
    *
    * For each QCS a per-stratum cap K is searched for (see [[find_K]]). A QCS for which no tried
    * K meets the error target is skipped. Each remaining QCS costs
    * `sampleRows * sizeOfRowBytes` bytes, where `sampleRows` is the sample size reported by the search.
    *
    * @param lineitem           the table to sample
    * @param storageBudgetBytes total bytes the chosen samples may occupy
    * @param e                  target relative error
    * @param ci                 confidence level; currently not used (see [[percentile]])
    * @return the sample RDDs and a map from query number to the index of its sample in that list
    */
  def sample(lineitem: DataFrame, storageBudgetBytes: Long, e: Double, ci: Double): (List[RDD[_]], _) = {
    val candidates: Seq[(Long, Long, Seq[Int], String)] =
      get_all_QCS().toSeq.flatMap { case (qcsColumns, queries) =>
        // Layout expected by which_samples_to_generate: (cost in bytes, K, queries, QCS).
        findKOption(lineitem, qcsColumns.split(" "), e, ci).toSeq.map { case (k, sampleRows) =>
          (sampleRows * sizeOfRowBytes, k.toLong, queries, qcsColumns)
        }
      }
    generate_samples(lineitem, which_samples_to_generate(candidates, storageBudgetBytes))
  }

  /** Materialises one stratified sample per entry of `Ks`.
    *
    * Rows of `lineitem` are grouped by the values of the entry's QCS columns. Each group
    * (stratum) is then cut down to at most K randomly chosen rows, and the sampling probability
    * is appended to each kept row as an extra string column.
    *
    * @param lineitem the table to sample
    * @param Ks       `(K, query numbers, space-separated QCS columns)` for each sample to build
    * @return the sample RDDs in the order of `Ks`, and a map from each query number to the index
    *         of its sample; when a query appears in several entries the last one wins
    */
  def generate_samples(lineitem: DataFrame, Ks: Seq[(Long, Seq[Int], String)]): (List[RDD[_]], _) = {
    val rdd = lineitem.rdd
    val samples: List[RDD[Row]] = Ks.toList.map { case (k, _, qcsColumns) =>
      val qcs = qcsColumns.split(" ")
      rdd
        .map(row => (stratumKey(row, qcs), row))
        .groupByKey()
        .flatMap { case (_, rows) => sampleStratum(rows.toVector, k) }
    }
    val queryMap: Map[Int, Int] = Ks.zipWithIndex.flatMap { case ((_, queries, _), index) =>
      queries.map(query => query -> index)
    }.toMap
    (samples, queryMap)
  }

  /** Joins the values of the `qcs` columns of `row` with single spaces (nulls render as "null"). */
  private def stratumKey(row: Row, qcs: Array[String]): String =
    qcs.map(column => row.getAs[Any](column)).mkString(" ")

  /** Keeps a uniform random subset of at most `k` rows of one stratum.
    *
    * Each kept row gets the sampling probability appended as a string column. That probability
    * is `k / N` when the stratum has N > k rows, and 1.0 when the stratum is kept whole.
    */
  private def sampleStratum(rows: Seq[Row], k: Long): Seq[Row] = {
    val n = rows.length.toLong
    val (kept, probability) =
      if (n <= k) (rows, 1.0)
      else (Random.shuffle(rows).take(k.toInt), k.toDouble / n) // k < n <= Int.MaxValue here
    kept.map(row => Row.fromSeq(row.toSeq :+ probability.toString))
  }

  /** Greedily chooses which candidate samples to build without exceeding `budget`.
    *
    * Candidates are ordered by the number of queries they serve (ascending; the sort is stable,
    * so ties keep their input order). Each candidate is then taken if its cost still fits in
    * the remaining budget.
    *
    * @param K_seq  `(cost, K, query numbers, QCS columns)` per candidate
    * @param budget total cost allowed
    * @return `(K, query numbers, QCS columns)` of the chosen candidates, in selection order
    */
  def which_samples_to_generate(K_seq: Seq[(Long, Long, Seq[Int], String)], budget: Long): Seq[(Long, Seq[Int], String)] = {
    val ordered = K_seq.sortBy(_._3.length)
    val (_, chosen) = ordered.foldLeft((budget, Vector.empty[(Long, Seq[Int], String)])) {
      case ((remaining, picked), (cost, k, queries, qcs)) =>
        if (cost <= remaining) (remaining - cost, picked :+ ((k, queries, qcs)))
        else (remaining, picked)
    }
    chosen
  }

  /** Groups `lineitem` by the QCS columns (`col1` followed by `qcs`) and computes three
    * aggregates per stratum: the row count ("count"), Spark `stddev` ("stddev") and sum ("sum")
    * of `aggAttr`.
    *
    * @return the per-stratum rows and the number of strata
    */
  val get_QCS_data = (lineitem: DataFrame, col1: String, aggAttr: String, qcs: Array[String]) => {
    val strata = lineitem
      .groupBy(col1, qcs: _*)
      .agg(count("*").as("count"), stddev(aggAttr).as("stddev"), sum(aggAttr).as("sum"))
      .rdd
    (strata, strata.count())
  }

  /** Checks whether capping every stratum at `K` rows meets the relative-error target `e`.
    *
    * Iterates the per-stratum rows on the driver and accumulates three quantities:
    *   - the estimated variance of the total of [[const_column]], `N_h^2 * S_h^2 / n_h` per
    *     stratum with `n_h = min(N_h, K)`;
    *   - the population total;
    *   - the resulting sample size.
    * The relative error is `percentile * sqrt(variance) / |total|`. The current value is printed.
    *
    * @param rdd   per-stratum rows with "count", "stddev" and "sum" columns
    * @param K     per-stratum row cap
    * @param e     target relative error
    * @param ci    confidence level; currently not used
    * @param count number of strata; currently not used
    * @return `(true, sampleRows)` if the relative error is at most `e`, otherwise `(false, 0)`
    */
  val guess_K = (rdd: RDD[Row], K: Double, e: Double, ci: Double, count: Long) => {
    val (variance, total, sampleRows) =
      rdd.toLocalIterator.foldLeft((0.0, 0.0, 0L)) { case ((varAcc, totalAcc, rowsAcc), row) =>
        val nh = row.getAs[Long]("count")
        val sh = stratumStdDev(row, nh)
        // Strata smaller than K are kept whole; larger ones contribute exactly K rows.
        val (taken, varianceTerm) =
          if (nh < K) (nh, nh * sh * sh)
          else (K.toLong, nh.toDouble * nh * sh * sh / K)
        (varAcc + varianceTerm, totalAcc + stratumSum(row), rowsAcc + taken)
      }
    if (total == 0.0) {
      (false, 0L)
    } else {
      val relErr = percentile * math.sqrt(variance) / math.abs(total)
      println(s"Relative error $relErr")
      if (relErr <= e) (true, sampleRows) else (false, 0L)
    }
  }

  /** Returns the stratum's "stddev" value, using 0 for single-row strata and for null/NaN values. */
  private def stratumStdDev(row: Row, nh: Long): Double = {
    val index = row.fieldIndex("stddev")
    if (nh <= 1 || row.isNullAt(index)) 0.0
    else {
      val s = row.getDouble(index)
      if (s.isNaN) 0.0 else s
    }
  }

  /** Returns the stratum's "sum" value as a Double (works for decimal and floating columns); null counts as 0. */
  private def stratumSum(row: Row): Double = row.getAs[Any]("sum") match {
    case n: java.lang.Number => n.doubleValue()
    case null                => 0.0
  }

  /** Aggregate column whose population total the error bound is computed for. */
  val const_column = "l_extendedprice"
  /** Factor by which K grows between successive guesses. */
  val K_INCREMENTS = 1.5
  /** First K tried by the search. */
  val INITIAL_K = 100.0

  /** Searches for a per-stratum cap K that meets the error target for the given QCS.
    *
    * @param qcs QCS column names; must be non-empty
    * @return `(K, sampleRows)` for the first K that passes [[guess_K]], or `null` if none does
    */
  val find_K: (DataFrame, Array[String], Double, Double) => (Double, Long) =
    (lineitem: DataFrame, qcs: Array[String], e: Double, ci: Double) =>
      findKOption(lineitem, qcs, e, ci).orNull

  /** Option-returning implementation of [[find_K]].
    *
    * Tries K = INITIAL_K, INITIAL_K * K_INCREMENTS, ... and stops at the first K that passes
    * [[guess_K]]. It gives up once K reaches the largest stratum: from that point every stratum
    * is kept whole, so a larger K cannot change the estimate.
    */
  private def findKOption(lineitem: DataFrame, qcs: Array[String], e: Double, ci: Double): Option[(Double, Long)] = {
    val (groups, strataCount) = get_QCS_data(lineitem, qcs.head, const_column, qcs.tail)
    // The strata are scanned once per guess; keep them cached for the duration of the search.
    groups.cache()
    try {
      val largestStratum = groups.map(_.getAs[Long]("count")).fold(0L)((a, b) => math.max(a, b))

      @scala.annotation.tailrec
      def search(k: Double): Option[(Double, Long)] = {
        val (good, sampleRows) = guess_K(groups, k, e, ci, strataCount)
        if (good) Some((k, sampleRows))
        else if (k >= largestStratum) None
        else search(k * K_INCREMENTS)
      }

      if (largestStratum == 0L) None else search(INITIAL_K)
    } finally {
      groups.unpersist()
    }
  }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement