Advertisement
fevzi02

Untitled

Nov 28th, 2023
20
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.09 KB | None | 0 0
  1. print(__doc__)
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from matplotlib import colors
  5. from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
  6.  
  7. cmap = colors.LinearSegmentedColormap('red_blue_classes', {'red': [(0, 1, 1), (1, 0.7, 0.7)], 'green': [(0, 0.7, 0.7), (1, 0.7, 0.7)], 'blue': [(0, 0.7, 0.7), (1, 1, 1)]})
  8. plt.cm.register_cmap(cmap=cmap)
  9.  
  10. def dataset_fixed_cov():
  11. n, dim = 300, 2
  12. np.random.seed(0)
  13. C = np.array([[0., -0.23], [0.83, .23]])
  14. X = np.r_[np.dot(np.random.randn(n, dim), C), np.dot(np.random.randn(n, dim), C) + np.array([1, 1])]
  15. y = np.hstack((np.zeros(n), np.ones(n)))
  16. return X, y
  17.  
  18. def dataset_cov():
  19. n, dim = 300, 2
  20. np.random.seed(0)
  21. C = np.array([[0., -1.], [2.5, .7]]) * 2.
  22. X = np.r_[np.dot(np.random.randn(n, dim), C), np.dot(np.random.randn(n, dim), C.T) + np.array([1, 4])]
  23. y = np.hstack((np.zeros(n), np.ones(n)))
  24. return X, y
  25.  
  26. def plot_data(lda, X, y, y_pred, fig_index, plot_cov=True):
  27. splot = plt.subplot(2, 2, fig_index)
  28. plt.title('Linear Discriminant Analysis' if fig_index == 1 else 'Quadratic Discriminant Analysis' if fig_index == 2 else '')
  29. plt.ylabel('Data with\n fixed covariance' if fig_index == 1 else 'Data with\n varying covariances' if fig_index == 3 else '')
  30.  
  31. tp = (y == y_pred)
  32. colors_list = ['red', 'blue']
  33. markers_list = ['.', 'x']
  34.  
  35. for i in range(2):
  36. tp_i, fp_i, X_i = tp[y == i], ~tp[y == i], X[y == i]
  37. X_tp, X_fp = X_i[tp_i], X_i[fp_i]
  38. plt.scatter(X_tp[:, 0], X_tp[:, 1], marker=markers_list[i], color=colors_list[i])
  39. plt.scatter(X_fp[:, 0], X_fp[:, 1], marker=markers_list[i], s=20, color='#990000' if i == 0 else '#000099')
  40.  
  41. nx, ny = 200, 100
  42. x_min, x_max, y_min, y_max = plt.xlim(), plt.ylim()
  43. xx, yy = np.meshgrid(np.linspace(x_min, x_max, nx), np.linspace(y_min, y_max, ny))
  44. Z = lda.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1].reshape(xx.shape)
  45. plt.pcolormesh(xx, yy, Z, cmap='red_blue_classes', norm=colors.Normalize(0., 1.), zorder=0)
  46. plt.contour(xx, yy, Z, [0.5], linewidths=2., colors='white')
  47. plt.plot(lda.means_[:, 0], lda.means_[:, 1], '*', color='yellow', markersize=15, markeredgecolor='grey')
  48.  
  49. if plot_cov:
  50. for i in range(2):
  51. plot_ellipse(splot, lda.means_[i], lda.covariance_, colors_list[i]) if fig_index == 1 else plot_ellipse(splot, qda.means_[i], qda.covariance_[i], colors_list[i])
  52.  
  53. splot.set_xticks(())
  54. splot.set_yticks(())
  55.  
  56. plt.figure(figsize=(10, 8), facecolor='white')
  57. plt.suptitle('Linear Discriminant Analysis vs Quadratic Discriminant Analysis', y=0.98, fontsize=15)
  58.  
  59. for i, (X, y) in enumerate([dataset_fixed_cov(), dataset_cov()]):
  60. lda, qda = LinearDiscriminantAnalysis(solver="svd", store_covariance=True), QuadraticDiscriminantAnalysis(store_covariance=True)
  61. plot_data(lda.fit(X, y), X, y, lda.predict(X), fig_index=2 * i + 1)
  62. plot_data(qda.fit(X, y), X, y, qda.predict(X), fig_index=2 * i + 2, plot_cov=(i == 1))
  63.  
  64. plt.tight_layout()
  65. plt.subplots_adjust(top=0.92)
  66. plt.show()
  67.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement