Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import java.sql.{DriverManager, ResultSet}
- import scala.annotation.tailrec
- import scala.collection._;
- import scala.collection.mutable.ListBuffer
- import util.Random
- import org.apache.spark.util.Vector
/**
 * A cluster used by the k-means routine: a center plus the points assigned to it.
 *
 * @param center the cluster's center vector
 * @param points mutable buffer of vectors assigned to this cluster; when omitted,
 *               each instance gets its own fresh, empty buffer
 */
class Cluster(
  val center: Vector,
  val points: ListBuffer[Vector] = ListBuffer.empty[Vector]
)
- object SimpleKMeans {
- def filledData(): immutable.HashMap[String, Vector] = { ...
- new immutable.HashMap[String, Vector]() ++ result
- }
/**
 * Picks at most `k` elements of `values`, chosen by shuffling the whole collection.
 *
 * @param k      how many elements to keep (defaults to 3); fewer are returned when
 *               `values` holds fewer than `k` elements
 * @param values the candidates to draw from
 * @return the first `k` elements of a random permutation of `values`
 */
def randomK(k: Int = 3, values: Iterable[Vector]) = {
  val shuffled = Random.shuffle(values)
  shuffled.take(k)
}
/**
 * Computes the element-wise sum of `points` divided by `length`, tracing the
 * inputs and intermediate values to standard output.
 *
 * The input vectors are never modified.
 *
 * @param points the vectors to average; must be non-empty, since `reduce` throws
 *               `UnsupportedOperationException` on an empty collection
 * @param length the divisor applied to the sum (callers pass the number of points)
 * @return a new vector holding `sum(points) / length`
 */
def mean(points: Iterable[Vector], length: Int): Vector = {
  println("In mean func:")
  println("points: " + points)
  println("length: " + length)
  // Use the allocating `+`, not `+=`: the in-place form accumulates the running
  // sum into the first point of the collection and corrupts the caller's data.
  val sum = points.reduce((x: Vector, y: Vector) => x + y)
  val result = sum / length
  println("sum: " + sum)
  println("result: " + result)
  println("*" * 42)
  result
}
/**
 * Runs k-means passes over `data` until the centers stop moving, tracing every
 * step to standard output.
 *
 * One pass assigns each point of `data` to the cluster whose center is nearest
 * (on a tie the earlier cluster wins), then recomputes every center as the mean
 * of its assigned points. A cluster that received no points keeps its previous
 * center. If no center moved, or the pass budget is spent, the new centers are
 * returned; otherwise the method recurses with them.
 *
 * @param data          the points to cluster, keyed by an identifier that is not
 *                      used by the algorithm itself
 * @param oldClusters   clusters whose centers seed this pass; their `points` are
 *                      ignored. When empty, an empty result is returned.
 * @param maxIterations upper bound on the number of passes, as a safety net against
 *                      non-convergence; values of 1 or less run exactly one pass
 * @return the final centers, one per input cluster, in input order
 */
@tailrec
def kMeans(
  data: immutable.HashMap[String, Vector],
  oldClusters: Iterable[Cluster],
  maxIterations: Int = 100
): Iterable[Vector] = {
  println(data.values)
  println("Input centers:")
  oldClusters.foreach(c => println(c.center))
  println("-" * 42)

  // Materialise once so every traversal below sees the same Cluster instances
  // (a lazy Iterable would otherwise rebuild them and drop the assigned points).
  val clusters = oldClusters.toList.map(c => new Cluster(c.center))
  println("Before choosing closest cluster:")
  clusters.foreach(c => println(c.points))

  if (clusters.nonEmpty) {
    for (point <- data.values) {
      // Strict `>` keeps the earlier cluster when two centers are equally close.
      val closest = clusters.reduceLeft { (prev, next) =>
        if ((prev.center dist point) > (next.center dist point)) next else prev
      }
      closest.points += point
    }
  }
  println("After choosing closest cluster:")
  clusters.foreach(c => println(c.points))

  // An empty cluster has no mean; keep its previous center instead of failing.
  val result = clusters.map { c =>
    if (c.points.isEmpty) c.center else mean(c.points.toList, c.points.length)
  }
  println("New centers:")
  println(result)
  println("-" * 42)

  // Converged when no center moved during this pass.
  val converged = clusters.zip(result).forall { case (c, r) => (c.center dist r) == 0.0 }

  if (converged || maxIterations <= 1)
    result
  else {
    println("New cluster:")
    val nextClusters = result.map { r =>
      println("---- " + r)
      new Cluster(r)
    }
    kMeans(data, nextClusters, maxIterations - 1)
  }
}
/**
 * Program entry point: loads the data set, picks random initial centers with
 * `randomK` (its default `k`), prints them, runs `kMeans` and prints its result.
 *
 * @param args command-line arguments (unused)
 */
def main(args: Array[String]): Unit = {
  val data = filledData()
  val centroids = randomK(values = data.values)
  centroids.foreach(println(_))
  val clusters = centroids.map(new Cluster(_))
  println()
  println(kMeans(data, clusters))
  println()
}
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement