Untitled

a guest Nov 19th, 2019 106 Never
import numpy as np
import matplotlib.pyplot as plt

def aPointToAll(z, X):
  # Squared Euclidean distance from point z to every row of X,
  # using ||x - z||^2 = ||x||^2 + ||z||^2 - 2 x.z
  z = np.ravel(z)  # flatten (1, d) to (d,) so the result has shape (n,), not (n, n)
  X2 = np.sum(X*X, axis=1)
  z2 = np.sum(z*z)
  return X2 + z2 - 2*X.dot(z)

def KNN(k, z):
  # Majority vote among the k nearest training points (labels are 0 or 1)
  dist = aPointToAll(z, X)
  listIndex = np.argpartition(dist, k-1)[:k]  # indices of the k smallest distances
  s = np.sum(lables[listIndex])
  if s > k/2:
    return 1
  return 0

X = np.array([[1,12], [2,5], [5,3], [3,2],[3,6],[1.5,9],[7,2],[6,1],[3.8,3],[3,10],[5.6,4],[4,2],[3.5,8],[2,11],[2.5,5],[2,9],[1,7]])
lables = np.array([0,0,1,1,0,1,1,1,1,0,1,1,0,0,1,0,0])
z = np.array([[2,2]])
gr1 = X[lables==0,:]
gr2 = X[lables==1,:]
plt.scatter(gr2[:,0], gr2[:,1], c='#9467bd')  # class 1: purple
plt.scatter(gr1[:,0], gr1[:,1], c='#17becf')  # class 0: cyan
# Draw the query point as a triangle in the color of its predicted class
if KNN(3,z) == 0:
  plt.scatter(z[:,0], z[:,1], c='#17becf',marker='v',s=100)
else:
  plt.scatter(z[:,0], z[:,1], c='#9467bd',marker='v',s=100)
plt.show()
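
The listing above relies on the expanded form of the squared distance, ||x - z||^2 = ||x||^2 + ||z||^2 - 2 x.z, and on a majority vote over the k nearest points. A minimal sketch of the same idea without plotting follows. The names `squared_distances`, `knn_predict`, and `labels` are hypothetical and not part of the original paste. It assumes binary 0/1 labels and checks the expanded distance against the direct computation.

```python
import numpy as np

def squared_distances(z, X):
    """Squared Euclidean distance from one point z to every row of X."""
    z = np.asarray(z, dtype=float).ravel()   # accept shape (d,) or (1, d)
    X = np.asarray(X, dtype=float)
    # ||x - z||^2 = ||x||^2 + ||z||^2 - 2 x.z, evaluated for all rows at once
    return np.sum(X * X, axis=1) + z.dot(z) - 2.0 * X.dot(z)

def knn_predict(z, X, labels, k):
    """Hypothetical helper: majority vote among the k nearest rows of X (labels 0/1)."""
    labels = np.asarray(labels)
    dist = squared_distances(z, X)
    nearest = np.argpartition(dist, k - 1)[:k]   # indices of the k smallest distances
    return int(labels[nearest].sum() > k / 2)

# Same data as the paste above
X = np.array([[1, 12], [2, 5], [5, 3], [3, 2], [3, 6], [1.5, 9], [7, 2], [6, 1],
              [3.8, 3], [3, 10], [5.6, 4], [4, 2], [3.5, 8], [2, 11], [2.5, 5],
              [2, 9], [1, 7]])
labels = np.array([0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0])
z = np.array([2, 2])

# The expanded form should agree with the direct definition
direct = np.sum((X - z) ** 2, axis=1)
assert np.allclose(squared_distances(z, X), direct)

print(knn_predict(z, X, labels, 3))  # prints 1
```

Using `k - 1` as the partition index keeps the call valid for any k from 1 up to the number of training points.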