Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import org.apache.spark.SparkContext
- import org.apache.spark.SparkConf
- import org.apache.log4j.Logger
- import org.apache.log4j.Level
- import org.apache.spark.rdd.RDD
- import scala.tools.nsc.matching.Matrix
- import scala.util.control.Breaks
/**
 * Spark demo that multiplies two small hard-coded integer matrices and prints the
 * product to standard output.
 *
 * Created by Alexey on 25.04.2016.
 */
object MultMatrix {

  /**
   * Entry point.
   *
   * Silences the "org" and "akka" loggers, builds a SparkContext against the master
   * URL given as the first argument, multiplies the two hard-coded matrices and
   * prints the product one row per line, every value followed by a tab.
   *
   * @param args command-line arguments; `args(0)` is passed to `SparkConf.setMaster`
   */
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.OFF)
    Logger.getLogger("akka").setLevel(Level.OFF)
    if (args.length < 1) {
      System.err.println("Usage: MultMatrix <master>")
      System.exit(1)
    }

    // NOTE(review): the application name still reads "SparkGrep"; it is kept as-is
    // because it is externally visible (Spark UI / logs) - confirm before renaming.
    val conf = new SparkConf().setAppName("SparkGrep").setMaster(args(0))
    val sc = new SparkContext(conf)
    try {
      // Right-hand operand (3 x 4).
      val m2: List[List[Int]] =
        List(
          List(1, 2, 3, 1),
          List(2, 1, 2, 1),
          List(1, 2, 1, 1))
      // Left-hand operand (4 x 3).
      val m1: List[List[Int]] =
        List(
          List(5, 2, 3),
          List(2, 6, 2),
          List(7, 2, 8),
          List(1, 2, 8))

      val product = multiply(sc, m1, m2)

      val resultRowsCount = m1.length
      val resultColumnsCount = m2.head.length
      for (row <- 0 until resultRowsCount) {
        // Every value is followed by a tab, including the last one of the row.
        val line = (0 until resultColumnsCount).map(col => s"${product((row, col))}\t").mkString
        println(line)
      }
    } finally {
      // Release the context even when the job fails.
      sc.stop()
    }
  }

  /**
   * Computes `left * right` and returns the product as a map keyed by (row, column).
   *
   * @param sc    context used to distribute the matrix cells
   * @param left  left operand, as a non-empty sequence of equally long rows
   * @param right right operand, as a non-empty sequence of equally long rows
   * @return every cell of the product, keyed by its zero-based (row, column) position
   * @throws IllegalArgumentException if an operand is empty or ragged, or if the
   *                                  width of `left` differs from the height of `right`
   */
  private def multiply(
      sc: SparkContext,
      left: Seq[Seq[Int]],
      right: Seq[Seq[Int]]): Map[(Int, Int), Int] = {
    require(left.nonEmpty && right.nonEmpty, "matrices must not be empty")
    require(left.forall(_.length == right.length),
      "every row of the left matrix must be as long as the right matrix is tall")
    require(right.forall(_.length == right.head.length),
      "all rows of the right matrix must have the same length")

    // groupBy gives no guarantee about the order of the elements inside a group,
    // so each row/column is explicitly sorted by its inner index before the
    // positional zip below.
    val rows: RDD[(Int, Vector[Int])] =
      toCells(sc, left)
        .groupBy(cell => cell._1._1)
        .map { case (rowIndex, cells) => rowIndex -> cells.toVector.sortBy(_._1._2).map(_._2) }
    val columns: RDD[(Int, Vector[Int])] =
      toCells(sc, right)
        .groupBy(cell => cell._1._2)
        .map { case (colIndex, cells) => colIndex -> cells.toVector.sortBy(_._1._1).map(_._2) }

    // Pair every row with every column and take the dot product of each pair.
    rows.cartesian(columns)
      .map { case ((rowIndex, rowValues), (colIndex, colValues)) =>
        (rowIndex, colIndex) -> rowValues.zip(colValues).map { case (a, b) => a * b }.sum
      }
      .collect()
      .toMap
  }

  /** Flattens a matrix into an RDD of ((row, column), value) cells. */
  private def toCells(sc: SparkContext, matrix: Seq[Seq[Int]]): RDD[((Int, Int), Int)] = {
    val cells = matrix.zipWithIndex.flatMap { case (line, rowIndex) =>
      line.zipWithIndex.map { case (value, colIndex) => (rowIndex, colIndex) -> value }
    }
    sc.parallelize(cells)
  }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement