Advertisement
jules0707

Kmeans

Apr 20th, 2017
328
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Scala 3.80 KB | None | 0 0
  1. package kmeans
  2.  
  import scala.annotation.tailrec
  import scala.collection._
  import scala.util.Random
  4.  
  class KMeans {

    /** Generates `num` pseudo-random points in 3-D space.
     *
     *  Point `i` gets, on each axis, a base offset `((i + c) % k) / k` (with c = 1, 5, 7
     *  for x, y, z) plus uniform noise in `[0, 0.5)`. Fixed seeds make the output
     *  identical from run to run.
     *
     *  @param k   number of distinct base offsets per axis; it is used as an integer
     *             modulus, so it must be non-zero whenever `num > 0`
     *  @param num number of points to generate
     *  @return the generated points, backed by a `mutable.ArrayBuffer`
     */
    def generatePoints(k: Int, num: Int): Seq[Point] = {
      val randx = new Random(1)
      val randy = new Random(3)
      val randz = new Random(5)
      (0 until num)
        .map { i =>
          val x = ((i + 1) % k) * 1.0 / k + randx.nextDouble() * 0.5
          val y = ((i + 5) % k) * 1.0 / k + randy.nextDouble() * 0.5
          val z = ((i + 7) % k) * 1.0 / k + randz.nextDouble() * 0.5
          new Point(x, y, z)
        }
        .to[mutable.ArrayBuffer]
    }

    /** Picks `k` initial means by sampling `points` uniformly at random *with replacement*
     *  (so the same point instance may be chosen more than once), using a fixed seed.
     *
     *  If `points` is empty and `k > 0`, `Random.nextInt(0)` throws
     *  `IllegalArgumentException`.
     *
     *  @return the chosen means, backed by a `mutable.ArrayBuffer`
     */
    def initializeMeans(k: Int, points: Seq[Point]): Seq[Point] = {
      val rand = new Random(7)
      (0 until k).map(_ => points(rand.nextInt(points.length))).to[mutable.ArrayBuffer]
    }

    /** Returns the element of `means` with the smallest squared distance to `p`.
     *  Ties go to the earliest such mean (the comparison is strict).
     *
     *  `means` must be non-empty (checked with `assert`).
     */
    def findClosest(p: Point, means: GenSeq[Point]): Point = {
      assert(means.nonEmpty)
      // Index loop rather than minBy: this runs once per point per iteration.
      var minDistance = p.squareDistance(means(0))
      var closest = means(0)
      var i = 1
      while (i < means.length) {
        val distance = p.squareDistance(means(i))
        if (distance < minDistance) {
          minDistance = distance
          closest = means(i)
        }
        i += 1
      }
      closest
    }

    /** Assigns every point to its closest mean.
     *
     *  Every element of `means` is a key of the result. A mean that attracts no points
     *  maps to an empty sequence, so that `update` can look up each old mean with
     *  `classified(m)` without a missing-key failure (credit: dnc 1994).
     */
    def classify(points: GenSeq[Point], means: GenSeq[Point]): GenMap[Point, GenSeq[Point]] = {
      val clusters = points.groupBy(findClosest(_, means))
      // Empty placeholders first; `++` lets the real clusters override them.
      means.map(m => (m, GenSeq[Point]())).toMap ++ clusters
    }

    /** Returns the coordinate-wise average of `points`, or `oldMean` itself when the
     *  cluster is empty.
     */
    def findAverage(oldMean: Point, points: GenSeq[Point]): Point =
      if (points.isEmpty) oldMean
      else {
        var x = 0.0
        var y = 0.0
        var z = 0.0
        // `.seq` forces a sequential traversal, so the local vars are never
        // updated from several threads when `points` is a parallel collection.
        points.seq.foreach { p =>
          x += p.x
          y += p.y
          z += p.z
        }
        val n = points.length
        new Point(x / n, y / n, z / n)
      }

    /** Computes the new means, in the same order as `oldMeans`, by averaging each
     *  mean's cluster from `classified`.
     */
    def update(classified: GenMap[Point, GenSeq[Point]], oldMeans: GenSeq[Point]): GenSeq[Point] =
      oldMeans.map(mean => findAverage(mean, classified(mean)))

    /** True when every old mean moved by a *squared* distance of at most `eta` to the new
     *  mean at the same index.
     *
     *  The means are compared pairwise by position, which is the order `update`
     *  produces. Sequences of different length are reported as not converged.
     */
    def converged(eta: Double)(oldMeans: GenSeq[Point], newMeans: GenSeq[Point]): Boolean =
      oldMeans.corresponds(newMeans)((oldMean, newMean) => oldMean.squareDistance(newMean) <= eta)

    /** Runs classify/update steps until `converged(eta)` holds and returns the means
     *  produced by the final step.
     *
     *  At least one update is always performed.
     */
    @tailrec
    final def kMeans(points: GenSeq[Point], means: GenSeq[Point], eta: Double): GenSeq[Point] = {
      // Compute the (expensive) step once; it serves both as the convergence input and as
      // the next iteration's means or the final result.
      val newMeans = update(classify(points, means), means)
      if (converged(eta)(means, newMeans)) newMeans
      else kMeans(points, newMeans, eta)
    }
  }
  77.  
  /** An immutable point in three-dimensional space.
   *
   *  Note: deliberately uses reference equality. `equals`/`hashCode` are not overridden,
   *  so two points with identical coordinates are still distinct (e.g. as map keys).
   */
  class Point(val x: Double, val y: Double, val z: Double) {
    private def sq(d: Double): Double = d * d

    /** Squared Euclidean distance to `that` (no square root is taken). */
    def squareDistance(that: Point): Double = {
      val dx = that.x - x
      val dy = that.y - y
      val dz = that.z - z
      sq(dx) + sq(dy) + sq(dz)
    }

    // Truncates (toward zero) to two decimal places; used only for display.
    private def twoDecimals(d: Double): Double = (d * 100).toInt / 100.0

    override def toString: String =
      Seq(x, y, z).map(twoDecimals).mkString("(", ", ", ")")
  }
  90.  
  91.  
  object KMeansRunner {

    // NOTE(review): `config`, `Key`, `Warmer` and `measure` are ScalaMeter APIs
    // (`org.scalameter._`), which this file does not import. Confirm the import exists
    // in the real build.
    val standardConfig = config(
      Key.exec.minWarmupRuns -> 20,
      Key.exec.maxWarmupRuns -> 40,
      Key.exec.benchRuns -> 25,
      Key.verbose -> true
    ) withWarmer(new Warmer.Default)

    /** Benchmarks `KMeans.kMeans` on generated data.
     *
     *  It first runs on the sequential collections, then on their `.par` counterparts,
     *  and prints both timings and the ratio between them.
     */
    def main(args: Array[String]): Unit = {
      val kMeans = new KMeans()

      val numPoints = 500000
      val eta = 0.01
      val k = 32
      val points = kMeans.generatePoints(k, numPoints)
      val means = kMeans.initializeMeans(k, points)

      val seqtime = standardConfig measure {
        kMeans.kMeans(points, means, eta)
      }
      println(s"sequential time: $seqtime ms")

      val partime = standardConfig measure {
        val parPoints = points.par
        val parMeans = means.par
        kMeans.kMeans(parPoints, parMeans, eta)
      }
      println(s"parallel time: $partime ms")
      println(s"speedup: ${seqtime / partime}")
    }

  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement