Advertisement
fevzi02

Untitled

Nov 28th, 2023 (edited)
22
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.57 KB | None | 0 0
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import matplotlib.patches as patches
  4. from scipy import linalg
  5. from matplotlib import colors
  6. from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
  7.  
  8. # #############################################################################
  9. # Colormap
  10. cmap = colors.LinearSegmentedColormap(
  11. 'red_blue_classes',
  12. {'red': [(0, 1, 1), (1, 0.7, 0.7)],
  13. 'green': [(0, 0.7, 0.7), (1, 0.7, 0.7)],
  14. 'blue': [(0, 0.7, 0.7), (1, 1, 1)]})
  15. plt.cm.register_cmap(cmap=cmap)
  16.  
  17. # #############################################################################
  18. # Generate datasets
  19. def dataset_fixed_cov():
  20. n, dim = 300, 2
  21. np.random.seed(0)
  22. C = np.array([[0., -0.23], [0.83, .23]])
  23. X = np.r_[np.dot(np.random.randn(n, dim), C),
  24. np.dot(np.random.randn(n, dim), C) + np.array([1, 1])]
  25. y = np.hstack((np.zeros(n), np.ones(n)))
  26. return X, y
  27.  
  28. def dataset_cov():
  29. n, dim = 300, 2
  30. np.random.seed(0)
  31. C = np.array([[0., -1.], [2.5, .7]]) * 2.
  32. X = np.r_[np.dot(np.random.randn(n, dim), C),
  33. np.dot(np.random.randn(n, dim), C.T) + np.array([1, 4])]
  34. y = np.hstack((np.zeros(n), np.ones(n)))
  35. return X, y
  36.  
  37. # #############################################################################
  38. # Plot functions
  39. def plot_data(lda, X, y, y_pred, fig_index):
  40. splot = plt.subplot(2, 2, fig_index)
  41. if fig_index == 1:
  42. plt.title('Линейный дискриминантный анализ')
  43. plt.ylabel('Данные с\n фиксированной ковариацией')
  44. elif fig_index == 2:
  45. plt.title('Квадратичный дискриминантный анализ')
  46. elif fig_index == 3:
  47. plt.ylabel('Данные с\n изменяющимися ковариациями')
  48.  
  49. tp = (y == y_pred)
  50. tp0, tp1 = tp[y == 0], tp[y == 1]
  51. X0, X1 = X[y == 0], X[y == 1]
  52. X0_tp, X0_fp = X0[tp0], X0[~tp0]
  53. X1_tp, X1_fp = X1[tp1], X1[~tp1]
  54.  
  55. plt.scatter(X0_tp[:, 0], X0_tp[:, 1], marker='.', color='red')
  56. plt.scatter(X0_fp[:, 0], X0_fp[:, 1], marker='x', s=20, color='#990000')
  57. plt.scatter(X1_tp[:, 0], X1_tp[:, 1], marker='.', color='blue')
  58. plt.scatter(X1_fp[:, 0], X1_fp[:, 1], marker='x', s=20, color='#000099')
  59.  
  60. # Вставляем код для построения эллипса для класса 0
  61. v, w = linalg.eigh(lda.covariance_)
  62. u = w[0] / linalg.norm(w[0])
  63. angle = np.arctan(u[1] / u[0])
  64. angle = 180 * angle / np.pi
  65. ell = patches.Ellipse(lda.means_[0], 2 * v[0] ** 0.5, 2 * v[1] ** 0.5,
  66. 180 + angle, facecolor='none', edgecolor='red', linewidth=2)
  67. splot.add_patch(ell)
  68.  
  69. # Вставляем код для построения эллипса для класса 1
  70. v, w = linalg.eigh(lda.covariance_)
  71. u = w[0] / linalg.norm(w[0])
  72. angle = np.arctan(u[1] / u[0])
  73. angle = 180 * angle / np.pi
  74. ell = patches.Ellipse(lda.means_[1], 2 * v[0] ** 0.5, 2 * v[1] ** 0.5,
  75. 180 + angle, facecolor='none', edgecolor='blue', linewidth=2)
  76. splot.add_patch(ell)
  77.  
  78. nx, ny = 200, 100
  79. x_min, x_max = plt.xlim()
  80. y_min, y_max = plt.ylim()
  81. xx, yy = np.meshgrid(np.linspace(x_min, x_max, nx),
  82. np.linspace(y_min, y_max, ny))
  83. Z = lda.predict_proba(np.c_[xx.ravel(), yy.ravel()])
  84. Z = Z[:, 1].reshape(xx.shape)
  85. plt.pcolormesh(xx, yy, Z, cmap='red_blue_classes', norm=colors.Normalize(0., 1.), zorder=0)
  86. plt.contour(xx, yy, Z, [0.5], linewidths=2., colors='white')
  87.  
  88. plt.plot(lda.means_[0][0], lda.means_[0][1], '*', color='yellow', markersize=15, markeredgecolor='grey')
  89. plt.plot(lda.means_[1][0], lda.means_[1][1], '*', color='yellow', markersize=15, markeredgecolor='grey')
  90.  
  91. splot.set_xticks(())
  92. splot.set_yticks(())
  93.  
  94. # #############################################################################
  95. # Main Code
  96. plt.figure(figsize=(10, 8), facecolor='white')
  97. plt.suptitle('Линейный Дискриминантный Анализ vs Квадратичный Дискриминантный Анализ',
  98. y=0.98, fontsize=15)
  99.  
  100. for i, (X, y) in enumerate([dataset_fixed_cov(), dataset_cov()]):
  101. lda = LinearDiscriminantAnalysis(solver="svd", store_covariance=True)
  102. y_pred = lda.fit(X, y).predict(X)
  103. splot = plot_data(lda, X, y, y_pred, fig_index=2 * i + 1)
  104. plt.axis('tight')
  105.  
  106. qda = QuadraticDiscriminantAnalysis(store_covariance=True)
  107. y_pred = qda.fit(X, y).predict(X)
  108.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement