Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
- import java.util.Random
- import org.apache.spark.SparkContext
- import org.apache.spark.util.Vector
- import org.apache.spark.SparkContext._
/** Parses one whitespace-separated line of numbers into a Spark `Vector`.
  *
  * @param line a line whose tokens are separated by single spaces and each
  *             parse as a `Double` (throws `NumberFormatException` otherwise)
  * @return a `Vector` holding the parsed coordinates in order
  */
def parseVector(line: String): Vector = {
  val coordinates = line.split(' ').map(_.toDouble)
  new Vector(coordinates)
}
/** Finds the index of the center nearest to `p` under squared Euclidean distance.
  *
  * @param p       the point to classify
  * @param centers candidate cluster centers; if empty, returns 0 (caller must
  *                guarantee non-emptiness for a meaningful result)
  * @return index into `centers` of the closest center (first one on ties)
  */
def closestPoint(p: Vector, centers: Array[Vector]): Int = {
  // Removed an unused `var index` from the original; a plain scan is kept
  // (rather than `minBy`) so an empty `centers` array still yields 0
  // instead of throwing.
  var bestIndex = 0
  var closest = Double.PositiveInfinity
  for (i <- centers.indices) {
    val tempDist = p.squaredDist(centers(i))
    if (tempDist < closest) {
      closest = tempDist
      bestIndex = i
    }
  }
  bestIndex
}
// --- K-means driver (spark-shell script; `sc` is the ambient SparkContext) ---

// Load and parse the input points once; cached because every iteration
// re-scans the full dataset.
val lines = sc.textFile("/storage/lm/kmeans_50M.txt")
val data = lines.map(parseVector _).cache()

val K = 500000            // number of clusters
val convergeDist = 0.1    // stop when total center movement falls below this

// Seeded sample so runs are reproducible.
val kPoints = data.takeSample(withReplacement = false, K, 42).toArray

var tempDist = 1.0
while (tempDist > convergeDist) {
  // Assign each point to its nearest center, carrying (sumVector, count).
  val closest = data.map(p => (closestPoint(p, kPoints), (p, 1)))
  val pointStats = closest.reduceByKey { case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2) }
  // New center = mean of its assigned points. NOTE: centers that attracted
  // no points have no entry in this map.
  val newPoints = pointStats.map { pair => (pair._1, pair._2._1 / pair._2._2) }.collectAsMap()

  tempDist = 0.0
  for (i <- 0 until K) {
    // BUG FIX: the original `newPoints(i)` threw NoSuchElementException for
    // empty clusters; an absent key means the center did not move this
    // iteration, so it contributes 0 to the convergence delta.
    tempDist += kPoints(i).squaredDist(newPoints.getOrElse(i, kPoints(i)))
  }

  // Only centers that received points get updated; empty clusters keep
  // their previous position.
  for (newP <- newPoints) {
    kPoints(newP._1) = newP._2
  }
  println("Finished iteration (delta = " + tempDist + ")")
}

println("Final centers:")
kPoints.foreach(println)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement