# Untitled

Nov 28th, 2023
16
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. print(__doc__)
2. from scipy import linalg
3. import numpy as np
4. import matplotlib.pyplot as plt
5. import matplotlib as mpl
6. from matplotlib import colors
7. from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
8.
9. cmap = colors.LinearSegmentedColormap('red_blue_classes', {'red': [(0, 1, 1), (1, 0.7, 0.7)], 'green': [(0, 0.7, 0.7), (1, 0.7, 0.7)], 'blue': [(0, 0.7, 0.7), (1, 1, 1)]})
10. plt.cm.register_cmap(cmap=cmap)
11.
12. def dataset_fixed_cov():
13. n, dim = 300, 2
14. np.random.seed(0)
15. C = np.array([[0., -0.23], [0.83, .23]])
16. X = np.r_[np.dot(np.random.randn(n, dim), C), np.dot(np.random.randn(n, dim), C) + np.array([1, 1])]
17. y = np.hstack((np.zeros(n), np.ones(n)))
18. return X, y
19.
20. def dataset_cov():
21. n, dim = 300, 2
22. np.random.seed(0)
23. C = np.array([[0., -1.], [2.5, .7]]) * 2.
24. X = np.r_[np.dot(np.random.randn(n, dim), C), np.dot(np.random.randn(n, dim), C.T) + np.array([1, 4])]
25. y = np.hstack((np.zeros(n), np.ones(n)))
26. return X, y
27.
28. def plot_data(lda, X, y, y_pred, fig_index, plot_cov):
29. splot = plt.subplot(2, 2, fig_index)
30. plt.title('Linear Discriminant Analysis' if fig_index == 1 else 'Quadratic Discriminant Analysis' if fig_index == 2 else '')
31. plt.ylabel('Data with\n fixed covariance' if fig_index == 1 else 'Data with\n varying covariances' if fig_index == 3 else '')
32.
33. tp = (y == y_pred)
34. colors_list = ['red', 'blue']
35. markers_list = ['.', 'x']
36.
37. for i in range(2):
38. tp_i, fp_i = tp[y == i], ~tp[y == i]
39. X_i = X[y == i]
40. X_tp, X_fp = X_i[tp_i], X_i[fp_i]
41. plt.scatter(X_tp[:, 0], X_tp[:, 1], marker=markers_list[i], color=colors_list[i])
42. plt.scatter(X_fp[:, 0], X_fp[:, 1], marker=markers_list[i], s=20, color='#990000' if i == 0 else '#000099')
43.
44. nx, ny = 200, 100
45. x_min, x_max = plt.xlim()
46. y_min, y_max = plt.ylim()
47. xx, yy = np.meshgrid(np.linspace(x_min, x_max, nx), np.linspace(y_min, y_max, ny))
48. Z = lda.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1].reshape(xx.shape)
49. plt.pcolormesh(xx, yy, Z, cmap='red_blue_classes', norm=colors.Normalize(0., 1.), zorder=0)
50. plt.contour(xx, yy, Z, [0.5], linewidths=2., colors='white')
51. plt.plot(lda.means_[0][0], lda.means_[0][1], '*', color='yellow', markersize=15, markeredgecolor='grey')
52. plt.plot(lda.means_[1][0], lda.means_[1][1], '*', color='yellow', markersize=15, markeredgecolor='grey')
53.
54. if plot_cov:
55. plot_ellipse(splot, lda.means_[0], lda.covariance_, 'red')
56. plot_ellipse(splot, lda.means_[1], lda.covariance_, 'blue') if fig_index != 1 else plot_ellipse(splot, qda.means_[1], qda.covariance_[1], 'blue')
57.
58. splot.set_xticks(())
59. splot.set_yticks(())
60.
61. plt.figure(figsize=(10, 8), facecolor='white')
62. plt.suptitle('Linear Discriminant Analysis vs Quadratic Discriminant Analysis', y=0.98, fontsize=15)
63.
64. for i, (X, y) in enumerate([dataset_fixed_cov(), dataset_cov()]):
65. lda = LinearDiscriminantAnalysis(solver="svd", store_covariance=True)
66. y_pred = lda.fit(X, y).predict(X)
67. plot_data(lda, X, y, y_pred, fig_index=2 * i + 1, plot_cov=True)
68.