Advertisement
toweber

knn_aula_p4_scaling

Sep 17th, 2021
208
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.07 KB | None | 0 0
import ipdb  # NOTE(review): unused in this script; requires the ipdb package
import matplotlib.pyplot as plt
import numpy as np
from sklearn import neighbors

  6. X = np.random.random([1000,2])
  7. Y = []
  8. for x_value in X:
  9.     y_value = 0
  10.     if ( ((x_value[0] <= 0.5) and (x_value[1] > 0.5))  or ((x_value[1] <= 0.5) and (x_value[0] > 0.5 ))  ):   # XOR "logic"
  11.         y_value = 1
  12.  
  13.     Y.append(y_value)
  14.  
  15.  
  16. X = np.array(X)
  17. Y = np.array(Y)
  18.  
  19. X[:,1] = 1*X[:,1]
  20. clf = neighbors.KNeighborsClassifier(n_neighbors=1)
  21.  
  22. clf = clf.fit(X, Y)
  23.  
  24.  
  25. #****************
  26. # Plot
  27. # baseado em https://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html
  28. #****************
  29. step_size = 0.01
  30. # encontrando os limites
  31. x_min = X[:,0].min()
  32. x_max = X[:,0].max()
  33. y_min = X[:,1].min()
  34. y_max = X[:,1].max()
  35.  
  36. xx, yy = np.meshgrid(np.arange(x_min, x_max, step_size), np.arange(y_min, y_max, step_size))
  37.  
  38. Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
  39.  
  40. Z = Z.reshape(xx.shape)
  41.  
  42. plt.figure()
  43. plt.pcolormesh(xx, yy, Z, cmap='Set3')
  44.  
  45. # plot training points
  46. plt.scatter(X[:, 0], X[:, 1], c=Y, edgecolor='k', s=20, cmap='Set3')
  47.  
  48. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement