• API
• FAQ
• Tools
• Archive
SHARE
TWEET

# Untitled

a guest Nov 19th, 2019 106 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. import numpy as np
2. import matplotlib.pyplot as plt
def aPointToAll(z, X):
    """Return squared Euclidean distances from point ``z`` to every row of ``X``.

    Uses the expansion ||x - z||^2 = ||x||^2 + ||z||^2 - 2 x.z so the whole
    computation is vectorised.

    Args:
        z: A single point, either 1-D ``(d,)`` or a row vector ``(1, d)``.
        X: Array of ``n`` points, shape ``(n, d)``.

    Returns:
        1-D array of shape ``(n,)`` holding the squared distance from ``z``
        to each row of ``X``.
    """
    X = np.asarray(X)
    # Flatten z so X.dot(z) is (n,) and lines up with X2. A (1, d) row vector
    # would make X.dot(z.T) an (n, 1) column that broadcasts against the (n,)
    # X2 into an (n, n) matrix instead of n distances.
    z = np.ravel(z)
    X2 = np.sum(X * X, axis=1)
    z2 = np.sum(z * z)
    # The expanded form can dip slightly below zero from float round-off;
    # a squared distance is never negative.
    return np.maximum(X2 + z2 - 2 * X.dot(z), 0)
8.
def KNN(k, z, data=None, labels=None):
    """Classify point ``z`` by majority vote of its ``k`` nearest neighbours.

    Args:
        k: Number of neighbours to consult; must satisfy ``1 <= k <= n``.
        z: The query point, either 1-D ``(d,)`` or a row vector ``(1, d)``.
        data: Training points, shape ``(n, d)``. Defaults to the module-level
            ``X`` so existing ``KNN(k, z)`` calls keep working.
        labels: Binary (0/1) label per training point, length ``n``. Defaults
            to the module-level ``lables``.

    Returns:
        1 if strictly more than half of the ``k`` nearest neighbours are
        labelled 1, otherwise 0 (so ties resolve to 0).

    Raises:
        ValueError: If ``k`` is outside ``[1, n]``.
    """
    if data is None:
        data = X
    if labels is None:
        labels = lables
    data = np.asarray(data, dtype=float)
    labels = np.asarray(labels)
    if not 1 <= k <= len(data):
        raise ValueError(f"k must be between 1 and {len(data)}, got {k}")

    # Squared Euclidean distance from z to every training point, shape (n,).
    point = np.ravel(z)
    dist = np.sum((data - point) ** 2, axis=1)

    # argpartition only guarantees that positions [:k] hold the k smallest
    # distances; it still returns an index for every point, so slice.
    # Using kth=k-1 also allows k == n.
    nearest = np.argpartition(dist, k - 1)[:k]
    votes = np.sum(labels[nearest])

    # Majority among the k neighbours, not among the whole training set.
    if votes > k / 2:
        return 1
    return 0
16.
# Training points (x, y) and their binary class labels.
X = np.array([[1, 12], [2, 5], [5, 3], [3, 2], [3, 6], [1.5, 9], [7, 2], [6, 1],
              [3.8, 3], [3, 10], [5.6, 4], [4, 2], [3.5, 8], [2, 11], [2.5, 5],
              [2, 9], [1, 7]])
lables = np.array([0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0])

# Query point to classify.
z = np.array([[2, 2]])

# Split the training set by label.
gr1 = X[lables == 0, :]
gr2 = X[lables == 1, :]

# Draw label 1 (purple) first, then label 0 (cyan).
for grp, color in ((gr2, '#9467bd'), (gr1, '#17becf')):
    plt.scatter(grp[:, 0], grp[:, 1], c=color)

# Mark the query point with a triangle in the colour of its predicted class.
z_color = '#17becf' if KNN(3, z) == 0 else '#9467bd'
plt.scatter(z[:, 0], z[:, 1], c=z_color, marker='v', s=100)
plt.show()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.

Top