// Intended for spark-shell: the `spark` value and `spark.implicits._` below
// assume the shell's pre-built SparkSession.
import scala.collection.mutable.Map
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.types._
import spark.implicits._
// One aligned span on a reference sequence ("bc" is presumably a barcode).
case class Span(
  ref_name: String,
  bc: String,
  beg: Int,
  end: Int,
  read_count: Int)
val spanSchema = StructType(
  Array(
    StructField("ref_name", StringType, true),
    StructField("bc", StringType, true),
    StructField("beg", IntegerType, true),
    StructField("end", IntegerType, true),
    StructField("read_count", IntegerType, true)
  )
)
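// Input note: the reader further below expects a headerless, tab-separated
// file with columns in spanSchema order. A hypothetical line (values invented
// purely for illustration):
//   chrI    bc42    1200    1450    7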
// Reduce a group of spans to per-base coverage, then coverage to break points.
object CalcBreakPoints extends Aggregator[Span, Map[Int, Int], Array[Int]] {

  // A zero value for this aggregation. Should satisfy the property that
  // any b + zero = b.
  def zero: Map[Int, Int] = Map[Int, Int]()

  // Fold one span into the coverage map. For performance, the function
  // may modify `buffer` and return it instead of constructing a new object.
  def reduce(buffer: Map[Int, Int], span: Span): Map[Int, Int] = {
    (span.beg until span.end).foreach(
      i => buffer += (i -> (buffer.getOrElse[Int](i, 0) + 1)))
    buffer
  }

  // Merge two intermediate coverage maps.
  def merge(b1: Map[Int, Int], b2: Map[Int, Int]): Map[Int, Int] = {
    b2.foreach {
      case (key, value) => b1 += (key -> (value + b1.getOrElse[Int](key, 0)))
    }
    b1
  }
  // Transform the final coverage map into break point coordinates: positions
  // where per-base coverage crosses cov_cutoff relative to the previous base.
  // Note that only covered coordinates are scanned, and the first one is
  // skipped, so edges at the start of the first covered run and just past
  // the end of a covered run are not reported.
  def finish(coverage: Map[Int, Int]): Array[Int] = {
    val cov_cutoff = 20
    val f = (i: Int) => if (i >= cov_cutoff) 1 else 0
    val coords = coverage.keys.toArray.sorted
    val bp = coords.slice(1, coords.length).map(
      c => {
        val current = f(coverage(c))
        val previous_step = f(coverage.getOrElse(c - 1, 0))
        (c, current - previous_step)
      })
      .filter { case (c, d) => d != 0 }
      .map { case (c, d) => c }
    bp
  }
  // Specifies the Encoder for the intermediate value type (Kryo-serialized).
  def bufferEncoder: Encoder[Map[Int, Int]] = Encoders.kryo[Map[Int, Int]]

  // Specifies the Encoder for the final output value type (Kryo-serialized,
  // so the resulting column is stored as binary).
  def outputEncoder: Encoder[Array[Int]] = Encoders.kryo[Array[Int]]
}
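// A minimal local sanity check of the aggregator logic (Span values invented
// for illustration): ten spans over [100, 200) plus fifteen over [100, 150)
// give coverage 25 on [100, 150) and 10 on [150, 200), so with
// cov_cutoff = 20 the only detectable step change is at coordinate 150.
val buf = CalcBreakPoints.zero
(1 to 10).foreach(_ => CalcBreakPoints.reduce(buf, Span("chrI", "bc0", 100, 200, 1)))
(1 to 15).foreach(_ => CalcBreakPoints.reduce(buf, Span("chrI", "bc0", 100, 150, 1)))
assert(CalcBreakPoints.finish(buf).sameElements(Array(150)))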
val ds = spark.read
  .option("sep", "\t")
  .schema(spanSchema)
  .csv("/projects/btl/zxue/assembly_correction/celegans/toy_cov.csv")
  .as[Span]
val cc = CalcBreakPoints.toColumn.name("bp")
val res = ds.groupByKey(a => a.ref_name).agg(cc)
res.write.format("parquet").save("./lele.parquet")
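// To spot-check the saved result, the parquet can be read back in the same
// session; because outputEncoder uses Kryo, "bp" appears as a binary column
// rather than a native array type:
val back = spark.read.parquet("./lele.parquet")
back.printSchema()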