Advertisement
Guest User

Untitled

a guest
Nov 19th, 2019
141
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.85 KB | None | 0 0
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. def aPointToAll(z, X):
  4. X2 = np.sum(X*X, axis=1)
  5. print(X2.shape)
  6. z2 = np.sum(z*z)
  7. return X2 + z2 -2*X.dot(z.T)
  8.  
  9. def KNN(k,z):
  10. dist = aPointToAll(z,X)
  11. listIndex = np.argpartition(dist, k)
  12. s = np.sum(lables[listIndex])
  13. if s > len(lables)/2:
  14. return 1
  15. return 0
  16.  
  17. X = np.array([[1,12], [2,5], [5,3], [3,2],[3,6],[1.5,9],[7,2],[6,1],[3.8,3],[3,10],[5.6,4],[4,2],[3.5,8],[2,11],[2.5,5],[2,9],[1,7]])
  18. lables =np.array([0,0,1,1,0,1,1,1,1,0,1,1,0,0,1,0,0])
  19. z = np.array([[2,2]])
  20. gr1 = X[lables==0,:]
  21. gr2 = X[lables==1,:]
  22. plt.scatter(gr2[:,0], gr2[:,1], c='#9467bd') #1 tim
  23. plt.scatter(gr1[:,0], gr1[:,1], c='#17becf') #0 xanh
  24. if KNN(3,z) == 0:
  25. plt.scatter(z[:,0], z[:,1], c='#17becf',marker='v',s=100)
  26. else:
  27. plt.scatter(z[:,0], z[:,1], c='#9467bd',marker='v',s=100)
  28. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement