Advertisement
Guest User

Untitled

a guest
Dec 8th, 2016
318
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.99 KB | None | 0 0
  1. import tensorflow as tf
  2. import numpy as np
  3. import time
  4.  
  5. N=10000
  6. K=4
  7. MAX_ITERS = 1000
  8.  
  9. start = time.time()
  10.  
  11. points = tf.Variable(tf.random_uniform([N,2]))
  12. cluster_assignments = tf.Variable(tf.zeros([N], dtype=tf.int64))
  13.  
  14. # Silly initialization: Use the first K points as the starting
  15. # centroids. In the real world, do this better.
  16. centroids = tf.Variable(tf.slice(points.initialized_value(), [0,0], [K,2]))
  17.  
  18. # Replicate to N copies of each centroid and K copies of each
  19. # point, then subtract and compute the sum of squared distances.
  20. rep_centroids = tf.reshape(tf.tile(centroids, [N, 1]), [N, K, 2])
  21. rep_points = tf.reshape(tf.tile(points, [1, K]), [N, K, 2])
  22. sum_squares = tf.reduce_sum(tf.square(rep_points - rep_centroids),
  23. reduction_indices=2)
  24.  
  25. # Use argmin to select the lowest-distance point
  26. best_centroids = tf.argmin(sum_squares, 1)
  27. did_assignments_change = tf.reduce_any(tf.not_equal(best_centroids,
  28. cluster_assignments))
  29.  
  30. def bucket_mean(data, bucket_ids, num_buckets):
  31. total = tf.unsorted_segment_sum(data, bucket_ids, num_buckets)
  32. count = tf.unsorted_segment_sum(tf.ones_like(data), bucket_ids, num_buckets)
  33. return total / count
  34.  
  35. means = bucket_mean(points, best_centroids, K)
  36.  
  37. # Do not write to the assigned clusters variable until after
  38. # computing whether the assignments have changed - hence with_dependencies
  39. with tf.control_dependencies([did_assignments_change]):
  40. do_updates = tf.group(
  41. centroids.assign(means),
  42. cluster_assignments.assign(best_centroids))
  43.  
  44. init = tf.initialize_all_variables()
  45.  
  46. sess = tf.Session()
  47. sess.run(init)
  48.  
  49. changed = True
  50. iters = 0
  51.  
  52. while changed and iters < MAX_ITERS:
  53. iters += 1
  54. [changed, _] = sess.run([did_assignments_change, do_updates])
  55.  
  56. [centers, assignments] = sess.run([centroids, cluster_assignments])
  57. end = time.time()
  58. print ("Found in %.2f seconds" % (end-start)), iters, "iterations"
  59. print "Centroids:"
  60. print centers
  61. print "Cluster assignments:", assignments
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement