/**
 * Created by michael on 9/15/16.
 * Self-contained reproduction of the aggregateByKey MatchError.
 */
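// Environment note (assumptions, not stated in the paste): a 2016-era Spark
// release with Scala 2.11 on the classpath; runs locally via the local[*]
// master configured below.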
import org.apache.spark.{SparkContext, SparkConf}
import scala.collection.immutable.HashMap
import scala.util.Random
import java.util.UUID.randomUUID
object mainBugTest {

  def main(args: Array[String]): Unit = {
    // Spark settings: Kryo serialization with reference tracking disabled
    // and the immutable HashMap value type registered explicitly.
    val sparkConf = new SparkConf()
    sparkConf.setAppName("mainBugTest")
    sparkConf.setMaster("local[*]")
    sparkConf.set("spark.io.compression.codec", "lzf")
    sparkConf.set("spark.speculation", "true")
    sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    sparkConf.set("spark.kryo.referenceTracking", "false")
    sparkConf.set("spark.kryoserializer.buffer", "64k")
    sparkConf.set("spark.kryoserializer.buffer.max", "1g")
    sparkConf.registerKryoClasses(Array(classOf[HashMap[String, Double]]))
    val sc = new SparkContext(sparkConf)

    // Data-generation parameters
    val tagCount = 200
    val tagPerEntity = 10
    val entityCount = 1000
    val entityRepeat = 200

    // Merge helpers used by aggregateByKey below.

    // Folds one (tagPair, weight) entry into the running map.
    def updateWeight(m: HashMap[String, Double], entry: (String, Double)): HashMap[String, Double] = {
      val v = m.getOrElse(entry._1, 0.0)
      m.updated(entry._1, v + entry._2)
    }

    // Combines two partial maps, summing the weights of shared keys.
    def mergeMaps(m1: HashMap[String, Double], m2: HashMap[String, Double]): HashMap[String, Double] =
      m1.merged(m2) { case ((k, v1), (_, v2)) => (k, v1 + v2) }
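
    // Worked example (plain Scala, no Spark needed):
    //   updateWeight(HashMap("ab" -> 1.0), ("ab", 2.0)) == HashMap("ab" -> 3.0)
    //   mergeMaps(HashMap("x" -> 1.0), HashMap("x" -> 2.0)) == HashMap("x" -> 3.0)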

    // Data creation: every entity id appears entityRepeat times, each time
    // paired with a fresh random sample of tagPerEntity tags.
    val tags = (1 to tagCount).map(_ => randomUUID().toString)
    val entityIds = (1 to entityCount).map(_ => randomUUID().toString.toUpperCase)
    val data = Random.shuffle(
      entityIds.flatMap(id =>
        (1 to entityRepeat).map(_ => id -> Random.shuffle(tags).take(tagPerEntity))))
    val rdd = sc.parallelize(data)
    println("# of rdd partitions: " + rdd.partitions.size)
    rdd.mapPartitions(iter => Array(iter.size).iterator, true) foreach println

    // Data manipulation: emit one (entityId, (tagPair, 1.0)) record for every
    // unordered pair of tags attached to an entity. The repartition only evens
    // out partition sizes so the test does not run out of memory.
    val newRdd = rdd.flatMap { case (id, entityTags) =>
      entityTags.toSet[String].subsets(2).map(pair => id -> (pair.mkString(""), 1.0))
    }.repartition(3)
    println("# of newRdd partitions: " + newRdd.partitions.size)
    newRdd.mapPartitions(iter => Array(iter.size).iterator, true) foreach println
    newRdd.take(10) foreach println

    // The aggregateByKey call that triggers the MatchError described in the
    // header: fold entries into per-partition HashMaps, then merge the maps.
    val result = newRdd.aggregateByKey(new HashMap[String, Double]())(
      { case (m, e) => updateWeight(m, e) },
      { case (m1, m2) => mergeMaps(m1, m2) }
    )
    result.take(10) foreach println
  }
}
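
// Minimal workaround sketch, not part of the original reproduction: a merge
// that avoids HashMap.merged and only uses getOrElse/updated. The object and
// method names below are hypothetical, and whether this sidesteps the
// MatchError is an assumption to verify against the failing job; if it does,
// mergeMapsFold can be passed to aggregateByKey in place of mergeMaps.
object mergeWithoutMerged {
  import scala.collection.immutable.HashMap

  // Fold the entries of m2 into m1, summing weights for shared keys.
  def mergeMapsFold(m1: HashMap[String, Double],
                    m2: HashMap[String, Double]): HashMap[String, Double] =
    m2.foldLeft(m1) { case (acc, (k, v)) =>
      acc.updated(k, acc.getOrElse(k, 0.0) + v)
    }
}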