Advertisement
Dundre32

Untitled

Jan 6th, 2021
493
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.01 KB | None | 0 0
  1. def plot_svm(X,X_s, y, alpha, w0, kernel = 'linearkernel', sigma = 0.5):
  2.     plt.scatter(X[:,0], X[:,1], c = y)
  3.  
  4.     plt.scatter(X_s[:,0], X_s[:,1], s=350, facecolors='none', edgecolors='black')
  5.  
  6.     x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
  7.     y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
  8.     xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
  9.                          np.arange(y_min, y_max, h))
  10.  
  11.     X_mesh = np.c_[xx.ravel(), yy.ravel()]
  12.  
  13.  
  14.     #print(set(Z - 1))
  15.     #np.sign(
  16.     #Make predictions
  17.     Z = discriminant(alpha,w0,X,y, X_mesh, kernel = kernel, sigma = sigma).reshape(xx.shape)
  18.  
  19.     plt.contour(xx, yy, Z, [0.0], colors='k', linewidths=1, origin='lower')
  20.     plt.contour(xx, yy, Z + 1, [0.0], colors='grey', linewidths=3, origin='lower')
  21.     plt.contour(xx, yy, Z - 1, [0.0], colors='green', linewidths=3, origin='lower')
  22.  
  23.     plt.xlim(xx.min(), xx.max())
  24.     plt.ylim(yy.min(), yy.max())
  25.     plt.xticks(())
  26.     plt.yticks(())
  27.     plt.axis("tight")
  28.     plt.show()
  29.    
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement