Advertisement
Guest User

Untitled

a guest
Mar 24th, 2014
153
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Scala 1.50 KB | None | 0 0
  1. import java.util.Random
  2. import org.apache.spark.SparkContext
  3. import org.apache.spark.util.Vector
  4. import org.apache.spark.SparkContext._
  5.  
  /**
   * Parses one line of whitespace-separated numbers into a Spark `Vector`.
   *
   * The line is trimmed and split on runs of whitespace. Extra spaces, tabs, or
   * leading/trailing blanks therefore do not produce empty tokens (an empty token
   * would make `toDouble` throw).
   *
   * @param line a text line such as "1.0 2.5 -3"
   * @return a vector holding the parsed values in their original order
   * @throws NumberFormatException if a token is not a valid double (including a blank line)
   */
  def parseVector(line: String): Vector =
    new Vector(line.trim.split("\\s+").map(_.toDouble))
  9.  
  /**
   * Returns the index of the center nearest to `p` by squared Euclidean distance.
   *
   * A later center must be strictly closer to replace the current best, so ties
   * resolve to the lowest index. If `centers` is empty, 0 is returned.
   *
   * @param p       the point to classify
   * @param centers the candidate cluster centers
   * @return the index into `centers` of the nearest center
   */
  def closestPoint(p: Vector, centers: Array[Vector]): Int = {
    var bestIndex = 0
    var closest = Double.PositiveInfinity
    var i = 0
    // This runs once per (point, center) pair. A plain while loop avoids the
    // closure and the heap Ref cells that a `for` over a Range needs for the
    // captured vars `closest` and `bestIndex`.
    while (i < centers.length) {
      val dist = p.squaredDist(centers(i))
      if (dist < closest) {
        closest = dist
        bestIndex = i
      }
      i += 1
    }
    bestIndex
  }
  25.  
  // Driver script for k-means over an RDD of vectors.
  // NOTE(review): assumes spark-shell, where `sc` (a SparkContext) is predefined.
  val lines = sc.textFile("/storage/lm/kmeans_50M.txt")
  val data = lines.map(parseVector _).cache()
  val K = 500000
  val convergeDist = 0.1

  // takeSample without replacement returns at most as many points as the RDD holds,
  // so the effective number of centers may be smaller than K. Code below never
  // assumes kPoints.length == K.
  val kPoints = data.takeSample(withReplacement = false, K, 42)
  // Start at +infinity so the first iteration always runs, whatever convergeDist is.
  var tempDist = Double.PositiveInfinity

  while (tempDist > convergeDist) {
    // Ship the current centers to each executor once through a broadcast variable
    // instead of serializing the whole array into every task closure.
    // NOTE(review): each iteration scans every center for every point in
    // closestPoint, so the cost grows with (number of points x K).
    val centers = sc.broadcast(kPoints)
    val closest = data.map(p => (closestPoint(p, centers.value), (p, 1)))

    val pointStats = closest.reduceByKey { case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2) }

    val newPoints = pointStats.map { case (i, (sum, count)) => (i, sum / count) }.collectAsMap()

    // Only clusters that received at least one point appear in newPoints. An empty
    // cluster keeps its previous center and contributes no movement, instead of
    // making a lookup for every index 0 until K fail.
    tempDist = newPoints.iterator.map { case (i, p) => kPoints(i).squaredDist(p) }.sum

    for ((i, p) <- newPoints) kPoints(i) = p

    println(s"Finished iteration (delta = $tempDist)")
  }

  println("Final centers:")
  kPoints.foreach(println)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement