Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
- import matplotlib.pyplot as plt
def aPointToAll(z, X):
    """Return the squared Euclidean distance from point ``z`` to every row of ``X``.

    Uses the expansion ||x - z||^2 = ||x||^2 + ||z||^2 - 2 x.z so that all
    distances come out of a single matrix-vector product.

    Parameters
    ----------
    z : array_like
        Query point, either a 1-D vector of length d or a single row of
        shape (1, d). It is flattened before use.
    X : array_like of shape (n, d)
        Reference points, one per row.

    Returns
    -------
    numpy.ndarray of shape (n,)
        Squared distances (no square root is taken; the ordering is the
        same as for true Euclidean distances, which is all k-NN needs).
    """
    X = np.asarray(X)
    # Flattening is essential: with a (1, d) query, X.dot(z.T) is (n, 1) and
    # adding the (n,) row norms would broadcast into an (n, n) matrix.
    z = np.ravel(z)
    X2 = np.sum(X * X, axis=1)
    z2 = np.dot(z, z)
    # The expansion can yield tiny negative values from floating-point
    # cancellation when a point coincides with z; clamp them to zero.
    return np.maximum(X2 + z2 - 2 * X.dot(z), 0)
def KNN(k, z, points=None, labels=None):
    """Classify ``z`` by majority vote among its ``k`` nearest neighbours.

    Parameters
    ----------
    k : int
        Number of neighbours to consult. A value larger than the number of
        reference points is clamped to that number.
    z : array_like
        Query point, 1-D of length d or a single row of shape (1, d).
    points : array_like of shape (n, d), optional
        Reference points. Defaults to the module-level ``X``.
    labels : array_like of shape (n,), optional
        Class label of each reference point; the vote sums labels, so they
        are expected to be 0/1. Defaults to the module-level ``lables``.

    Returns
    -------
    int
        1 if strictly more than half of the ``k`` neighbours are labelled 1,
        otherwise 0 (a tie with even ``k`` therefore resolves to 0).

    Raises
    ------
    ValueError
        If ``k`` is less than 1 or there are no labelled reference points.
    """
    if points is None:
        points = X
    if labels is None:
        labels = lables
    labels = np.asarray(labels)
    n = len(labels)
    if k < 1 or n == 0:
        raise ValueError("k must be >= 1 and at least one labelled point is required")
    k = min(k, n)
    # Pass a flat query so the distance vector is 1-D with one entry per point.
    dist = aPointToAll(np.ravel(z), points)
    # argpartition only guarantees that the k smallest distances occupy the
    # first k slots, so the result must be sliced to keep just those neighbours.
    nearest = np.argpartition(dist, k - 1)[:k]
    votes = np.sum(labels[nearest])
    # Majority among the k neighbours, not among the whole dataset.
    if votes > k / 2:
        return 1
    return 0
# Training set: 2-D points and the binary class label of each one.
X = np.array([
    [1, 12], [2, 5], [5, 3], [3, 2], [3, 6], [1.5, 9],
    [7, 2], [6, 1], [3.8, 3], [3, 10], [5.6, 4], [4, 2],
    [3.5, 8], [2, 11], [2.5, 5], [2, 9], [1, 7],
])
lables = np.array([0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0])

# Query point to classify, as a single row.
z = np.array([[2, 2]])

_CLASS0_COLOR = '#17becf'  # cyan, used for label 0
_CLASS1_COLOR = '#9467bd'  # purple, used for label 1

# Split the training points by label.
gr1 = X[lables == 0, :]
gr2 = X[lables == 1, :]

# Plot class 1 first, then class 0.
plt.scatter(gr2[:, 0], gr2[:, 1], c=_CLASS1_COLOR)
plt.scatter(gr1[:, 0], gr1[:, 1], c=_CLASS0_COLOR)

# Mark the query with a large triangle in the colour of its predicted class.
query_color = _CLASS0_COLOR if KNN(3, z) == 0 else _CLASS1_COLOR
plt.scatter(z[:, 0], z[:, 1], c=query_color, marker='v', s=100)
plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement