Advertisement
fevzi02

Untitled

Nov 28th, 2023
22
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.34 KB | None | 0 0
  1. print(__doc__)
  2. from scipy import linalg
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. import matplotlib as mpl
  6. from matplotlib import colors
  7. from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
  8.  
  9. cmap = colors.LinearSegmentedColormap('red_blue_classes', {'red': [(0, 1, 1), (1, 0.7, 0.7)], 'green': [(0, 0.7, 0.7), (1, 0.7, 0.7)], 'blue': [(0, 0.7, 0.7), (1, 1, 1)]})
  10. plt.cm.register_cmap(cmap=cmap)
  11.  
  12. def dataset_fixed_cov():
  13. n, dim = 300, 2
  14. np.random.seed(0)
  15. C = np.array([[0., -0.23], [0.83, .23]])
  16. X = np.r_[np.dot(np.random.randn(n, dim), C), np.dot(np.random.randn(n, dim), C) + np.array([1, 1])]
  17. y = np.hstack((np.zeros(n), np.ones(n)))
  18. return X, y
  19.  
  20. def dataset_cov():
  21. n, dim = 300, 2
  22. np.random.seed(0)
  23. C = np.array([[0., -1.], [2.5, .7]]) * 2.
  24. X = np.r_[np.dot(np.random.randn(n, dim), C), np.dot(np.random.randn(n, dim), C.T) + np.array([1, 4])]
  25. y = np.hstack((np.zeros(n), np.ones(n)))
  26. return X, y
  27.  
  28. def plot_data(lda, X, y, y_pred, fig_index, plot_cov):
  29. splot = plt.subplot(2, 2, fig_index)
  30. plt.title('Linear Discriminant Analysis' if fig_index == 1 else 'Quadratic Discriminant Analysis' if fig_index == 2 else '')
  31. plt.ylabel('Data with\n fixed covariance' if fig_index == 1 else 'Data with\n varying covariances' if fig_index == 3 else '')
  32.  
  33. tp = (y == y_pred)
  34. colors_list = ['red', 'blue']
  35. markers_list = ['.', 'x']
  36.  
  37. for i in range(2):
  38. tp_i, fp_i = tp[y == i], ~tp[y == i]
  39. X_i = X[y == i]
  40. X_tp, X_fp = X_i[tp_i], X_i[fp_i]
  41. plt.scatter(X_tp[:, 0], X_tp[:, 1], marker=markers_list[i], color=colors_list[i])
  42. plt.scatter(X_fp[:, 0], X_fp[:, 1], marker=markers_list[i], s=20, color='#990000' if i == 0 else '#000099')
  43.  
  44. nx, ny = 200, 100
  45. x_min, x_max = plt.xlim()
  46. y_min, y_max = plt.ylim()
  47. xx, yy = np.meshgrid(np.linspace(x_min, x_max, nx), np.linspace(y_min, y_max, ny))
  48. Z = lda.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1].reshape(xx.shape)
  49. plt.pcolormesh(xx, yy, Z, cmap='red_blue_classes', norm=colors.Normalize(0., 1.), zorder=0)
  50. plt.contour(xx, yy, Z, [0.5], linewidths=2., colors='white')
  51. plt.plot(lda.means_[0][0], lda.means_[0][1], '*', color='yellow', markersize=15, markeredgecolor='grey')
  52. plt.plot(lda.means_[1][0], lda.means_[1][1], '*', color='yellow', markersize=15, markeredgecolor='grey')
  53.  
  54. if plot_cov:
  55. plot_ellipse(splot, lda.means_[0], lda.covariance_, 'red')
  56. plot_ellipse(splot, lda.means_[1], lda.covariance_, 'blue') if fig_index != 1 else plot_ellipse(splot, qda.means_[1], qda.covariance_[1], 'blue')
  57.  
  58. splot.set_xticks(())
  59. splot.set_yticks(())
  60.  
  61. plt.figure(figsize=(10, 8), facecolor='white')
  62. plt.suptitle('Linear Discriminant Analysis vs Quadratic Discriminant Analysis', y=0.98, fontsize=15)
  63.  
  64. for i, (X, y) in enumerate([dataset_fixed_cov(), dataset_cov()]):
  65. lda = LinearDiscriminantAnalysis(solver="svd", store_covariance=True)
  66. y_pred = lda.fit(X, y).predict(X)
  67. plot_data(lda, X, y, y_pred, fig_index=2 * i + 1, plot_cov=True)
  68.  
  69. qda = QuadraticDiscriminantAnalysis(store_covariance=True)
  70. y_pred = qda.fit(X, y).predict(X)
  71. plot_data(qda, X, y, y_pred, fig_index=2 * i + 2, plot_cov=(i == 1))
  72.  
  73. plt.tight_layout()
  74. plt.subplots_adjust(top=0.92)
  75. plt.show()
  76.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement