Advertisement
Guest User

Untitled

a guest
Oct 17th, 2019
114
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.00 KB | None | 0 0
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4.  
  5. cl_num = 3
  6. data_num = 20
  7. thr = [0.00001, 0.00001, 0.00001]
  8.  
  9.  
  10. def dist(x, y, mu_x, mu_y):
  11. return ((mu_x - x)**2 + (mu_y - y)**2)
  12.  
  13.  
  14. def cluster(x, y, mu_x, mu_y):
  15.  
  16. cls_ = dict()
  17. for i in range(data_num):
  18. dists = []
  19. for j in range(cl_num):
  20. distant = dist(x[i], y[i], mu_x[j], mu_y[j])
  21. dists.append(distant)
  22. cl = dists.index(min(dists))
  23. if cl not in cls_:
  24. cls_[cl] = [(x[i], y[i])]
  25. elif cl in cls_:
  26. cls_[cl].append((x[i], y[i]))
  27.  
  28. return cls_
  29.  
  30.  
  31. def re_mu(cls_, mu_x, mu_y):
  32. new_muX = []
  33. new_muY = []
  34.  
  35. for key, values in cls_.items():
  36.  
  37. if len(values) == 0:
  38. values.append([mu_x[key], mu_y[key]])
  39.  
  40. sum_x = 0
  41. sum_y = 0
  42. for v in values:
  43. sum_x += v[0]
  44. sum_y += v[1]
  45.  
  46. new_mu_x = sum_x / len(values)
  47. new_mu_y = sum_y / len(values)
  48.  
  49. new_muX.append(round(new_mu_x, 2))
  50. new_muY.append(round(new_mu_y, 2))
  51. return new_muX, new_muY
  52.  
  53.  
  54. def main():
  55.  
  56. x = np.random.randint(0, 500, data_num)
  57. y = np.random.randint(0, 500, data_num)
  58.  
  59. mu_x = np.random.randint(0, 500, cl_num)
  60. mu_y = np.random.randint(0, 500, cl_num)
  61.  
  62. cls_ = cluster(x, y, mu_x, mu_y)
  63.  
  64. new_muX, new_muY = re_mu(cls_, mu_x, mu_y)
  65.  
  66. while any((abs(np.array(new_muX) - np.array(mu_x)) > thr)) != False or any(
  67. (abs(np.array(new_muY) - np.array(mu_y)) > thr)) != False:
  68. mu_x = new_muX
  69. mu_y = new_muY
  70. cls_ = cluster(x, y, mu_x, mu_y)
  71. new_muX, new_muY = re_mu(cls_, mu_x, mu_y)
  72.  
  73. print('Done')
  74.  
  75. plt.scatter(x, y)
  76. plt.scatter(new_muX, new_muY)
  77. plt.show()
  78.  
  79. colors = ['r', 'b', 'g']
  80. for key, values in cls_.items():
  81. cx = []
  82. cy = []
  83. for v in values:
  84. cx.append(v[0])
  85. cy.append(v[1])
  86. plt.scatter(cx, cy, color=colors[key])
  87.  
  88. plt.show()
  89.  
  90.  
  91. if __name__ == '__main__':
  92. main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement