toweber

k-means_visualization

Sep 17th, 2021 (edited)
305
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.85 KB | None | 0 0
  1. import os
  2. import ipdb # ipdb.set_trace()  
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5. from sklearn.cluster import KMeans
  6.  
  7. np.random.seed(0)
  8.  
  9. samples = 1000 #por grupo
  10. # grupo artificialmente criado 1, ao redor do ponto [5,10], com desvio padrão 0.5
  11.  
  12. coordX_1 = 5+2.5*np.random.randn(samples)
  13. coordY_1 = 10+2.5*np.random.randn(samples)
  14.  
  15. # grupo artificialmente criado 2 ao redor do ponto [2,2], com desvio padrão 0.5
  16. coordX_2 = 2+0.5*np.random.randn(samples)
  17. coordY_2 = 2+0.5*np.random.randn(samples)
  18.  
  19. # grupo artificialmente criado 3 ao redor do ponto [7,3], com desvio padrão 0.5
  20. coordX_3 = 7+0.5*np.random.randn(samples)
  21. coordY_3 = 3+0.5*np.random.randn(samples)
  22.  
  23.  
  24. coordX = np.concatenate([coordX_1, coordX_2, coordX_3])
  25. coordY = np.concatenate([coordY_1, coordY_2, coordY_3])
  26.  
  27. X = np.column_stack([coordX,coordY])
  28.  
  29.  
  30. #plt.scatter(X[:,0],X[:,1])
  31. #plt.show()
  32.  
  33. #ipdb.set_trace()
  34.  
  35. kmeans = KMeans(n_clusters=5, random_state=0).fit(X)
  36.  
  37. print(kmeans.labels_)
  38. print(kmeans.cluster_centers_)
  39.  
  40.  
  41.  
  42.  
  43. # plot the decision boundaries, based on https://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_digits.html
  44. step = 0.1
  45. x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
  46. y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
  47.  
  48. xx, yy = np.meshgrid(np.arange(x_min, x_max, step), np.arange(y_min, y_max, step))
  49.  
  50. Z = kmeans.predict(np.c_[xx.ravel(), yy.ravel()])
  51.  
  52. Z = Z.reshape(xx.shape)
  53.  
  54.  
  55. # plotar o resultado dos clusters baseado em cor
  56. plt.imshow(Z, interpolation='nearest',
  57.            extent=(xx.min(), xx.max(), yy.min(), yy.max()),
  58.            cmap=plt.cm.Paired,
  59.            aspect='auto', origin='lower')
  60.  
  61.  
  62. # plotar os dados
  63. plt.plot(X[:, 0], X[:, 1], 'k.', markersize=2)
  64.  
  65.  
  66. # plotar os centroides
  67. centroids = kmeans.cluster_centers_
  68.  
  69. plt.plot(centroids[:,0],centroids[:,1], 'r.', markersize=10)
  70.  
  71. plt.show()
  72.  
Add Comment
Please, Sign In to add comment