import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs


def plot_k_means(x, r, k, centers, colors):
    # Scatter the points with responsibility-blended colors and mark
    # each cluster center with a red dot.
    plt.scatter(x[:, 0], x[:, 1], c=colors)

    for c in centers:
        plt.plot(c[0], c[1], 'ro')
    plt.show()


def initialize_centers(x, num_k):
    # Choose num_k distinct data points at random as the initial centers.
    N, D = x.shape
    centers = np.zeros((num_k, D))
    used_idx = []
    for k in range(num_k):
        idx = np.random.choice(N)
        while idx in used_idx:
            idx = np.random.choice(N)
        used_idx.append(idx)
        centers[k] = x[idx]
    return centers


def update_centers(x, r, K):
    # Each center is the responsibility-weighted mean of all points.
    N, D = x.shape
    centers = np.zeros((K, D))
    for k in range(K):
        centers[k] = r[:, k].dot(x) / r[:, k].sum()
    return centers


def square_dist(a, b):
    # Element-wise squared difference (defined but not used below).
    return (a - b) ** 2


def cost_func(x, r, centers, K):
    # Responsibility-weighted sum of distances from each point to each center.
    cost = 0
    for k in range(K):
        norm = np.linalg.norm(x - centers[k], 2, axis=1)
        cost += (norm * r[:, k]).sum()
    return cost


def cluster_responsibilities(centers, x, beta):
    # Soft assignment: R[n, k] is proportional to exp(-beta * ||x_n - c_k||),
    # normalized so each row sums to 1.
    N, _ = x.shape
    K, D = centers.shape
    R = np.zeros((N, K))

    for n in range(N):
        R[n] = np.exp(-beta * np.linalg.norm(centers - x[n], 2, axis=1))
    R /= R.sum(axis=1, keepdims=True)

    return R


def return_responsibilities(R):
    # Hard-assign each point to its most responsible cluster. The swap of
    # columns 0 and 2 is a hard-coded permutation that maps cluster indices
    # onto the label order of this particular make_blobs run.
    a = []
    for i, r in enumerate(R):
        r[0], r[2] = r[2], r[0]
        a.append(np.argmax(r))
        r[0], r[2] = r[2], r[0]
    return a


def calculate_partial_accuracy(X, labels, R):
    # Fraction (in percent) of points whose hard assignment matches the
    # true label. X is accepted for symmetry with the other helpers but
    # is not used here.
    responsibilities = return_responsibilities(R)
    score = 0
    for i, l in enumerate(labels):
        if responsibilities[i] == l:
            score += 1
    return float(score) / len(responsibilities) * 100


def soft_k_means(x, labels, K, max_iters=20, beta=1.):
    np.random.seed(5)
    random_colors = np.random.random((K, 3))
    centers = initialize_centers(x, K)
    r = cluster_responsibilities(centers, x, beta)
    colors = r.dot(random_colors)
    print('Initialize Plot')
    # plot_k_means(x, r, K, centers, colors)
    max_i = []
    accuracies = []
    prev_cost = 0
    for i in range(max_iters):
        r = cluster_responsibilities(centers, x, beta)
        colors = r.dot(random_colors)
        centers = update_centers(x, r, K)
        cost = cost_func(x, r, centers, K)
        print('Iteration: ' + str(i))
        print(centers)
        # plot_k_means(x, r, K, centers, colors)
        acc = calculate_partial_accuracy(x, labels, r)
        accuracies.append(acc)
        max_i.append(i)
        # print('Accuracy: ' + str(acc) + '%')
        if acc == 100:
            print('Breaking: Fully Accurate')
            break
        if np.abs(cost - prev_cost) < 1e-5:
            print('Breaking: Cost converged')
            break
        prev_cost = cost
    print('Finish Plot')
    plot_k_means(x, r, K, centers, colors)
    r = cluster_responsibilities(centers, x, beta)
    acc = calculate_partial_accuracy(x, labels, r)
    print(max_i)
    # plt.plot(max_i, accuracies)
    # plt.show()
    print('Accuracy: ' + str(acc) + '%')

X, labels = make_blobs(n_samples=100, centers=3, cluster_std=1.5, random_state=1)
soft_k_means(X, labels, K=3)
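
# --- Optional sanity check (a minimal sketch, not part of the original paste) ---
# The hard-coded 0/2 column swap in return_responsibilities only works for this
# particular make_blobs seed. As an assumption-free alternative, the helper below
# scores accuracy over the best permutation of cluster ids and compares against
# scikit-learn's hard KMeans on the same data. The names best_permutation_accuracy
# and km are illustrative, not from the original code.
from itertools import permutations
from sklearn.cluster import KMeans


def best_permutation_accuracy(pred, labels, K):
    # Try every mapping of predicted cluster ids onto true label ids and
    # keep the highest agreement, reported in percent.
    best = 0.0
    for perm in permutations(range(K)):
        mapped = np.array([perm[p] for p in pred])
        best = max(best, (mapped == labels).mean() * 100)
    return best


km = KMeans(n_clusters=3, n_init=10, random_state=1).fit(X)
print('sklearn KMeans accuracy: ' + str(best_permutation_accuracy(km.labels_, labels, 3)) + '%')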