Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from sklearn import neighbors
- import numpy as np
- import matplotlib.pyplot as plt
- import ipdb
# Sample 1000 points uniformly from the unit square [0, 1) x [0, 1).
X = np.random.random([1000, 2])

# XOR labelling: a point is class 1 when exactly one of its coordinates
# exceeds 0.5 (top-left and bottom-right quadrants) and class 0 otherwise.
# This is the vectorised equivalent of
#   (x0 <= 0.5 and x1 > 0.5) or (x1 <= 0.5 and x0 > 0.5)
Y = ((X[:, 0] > 0.5) != (X[:, 1] > 0.5)).astype(int)

# Scale factor for the second feature, applied AFTER labelling. With 1 it is a
# no-op. Other values stretch that axis and therefore change the distances
# the nearest-neighbour classifier sees, without changing the labels.
# NOTE(review): presumably kept as an experiment knob; confirm intent.
y_axis_scale = 1
X[:, 1] = y_axis_scale * X[:, 1]
# Build a 1-nearest-neighbour classifier and train it on the XOR data in one
# step. fit() hands back the estimator it was called on, so the chained result
# is the trained classifier.
clf = neighbors.KNeighborsClassifier(n_neighbors=1).fit(X, Y)
# ****************
# Plot
# Based on https://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html
# ****************
step_size = 0.01  # mesh resolution, in data units

# Find the plotting bounds. Pad by one step on each side. np.arange excludes
# its stop value, so without padding the mesh could stop short of the largest
# sample and leave edge points outside the shaded decision surface.
x_min = X[:, 0].min() - step_size
x_max = X[:, 0].max() + step_size
y_min = X[:, 1].min() - step_size
y_max = X[:, 1].max() + step_size
xx, yy = np.meshgrid(
    np.arange(x_min, x_max, step_size),
    np.arange(y_min, y_max, step_size),
)

# Classify every mesh node, then fold the flat predictions back onto the grid.
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

plt.figure()
# Z has the same shape as xx/yy, so request 'auto' shading (one cell centred on
# each node) instead of relying on a default that has changed across
# matplotlib releases.
# Pin vmin/vmax to the label range so the surface and the training points
# below map the same class to the same colour, even if a class is absent
# from Z.
plt.pcolormesh(xx, yy, Z, cmap='Set3', shading='auto',
               vmin=Y.min(), vmax=Y.max())

# Overlay the training points, outlined in black so they stay visible against
# the same-coloured background.
plt.scatter(X[:, 0], X[:, 1], c=Y, edgecolor='k', s=20, cmap='Set3',
            vmin=Y.min(), vmax=Y.max())
plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement