"""Comparison of LDA and QDA decision boundaries on two synthetic
two-class Gaussian datasets: one with a shared covariance matrix,
one with per-class covariances."""
print(__doc__)

from scipy import linalg
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import colors
from sklearn.discriminant_analysis import (LinearDiscriminantAnalysis,
                                           QuadraticDiscriminantAnalysis)

# Colormap: class-0 regions shade toward red, class-1 regions toward blue.
cmap = colors.LinearSegmentedColormap(
    'red_blue_classes',
    {'red':   [(0, 1, 1), (1, 0.7, 0.7)],
     'green': [(0, 0.7, 0.7), (1, 0.7, 0.7)],
     'blue':  [(0, 0.7, 0.7), (1, 1, 1)]})
mpl.colormaps.register(cmap)  # plt.cm.register_cmap is deprecated/removed in newer Matplotlib
def dataset_fixed_cov():
    """Generate two Gaussian samples with the same covariance matrix."""
    n, dim = 300, 2
    np.random.seed(0)
    C = np.array([[0., -0.23], [0.83, .23]])
    X = np.r_[np.dot(np.random.randn(n, dim), C),
              np.dot(np.random.randn(n, dim), C) + np.array([1, 1])]
    y = np.hstack((np.zeros(n), np.ones(n)))
    return X, y
def dataset_cov():
    """Generate two Gaussian samples with different covariance matrices."""
    n, dim = 300, 2
    np.random.seed(0)
    C = np.array([[0., -1.], [2.5, .7]]) * 2.
    X = np.r_[np.dot(np.random.randn(n, dim), C),
              np.dot(np.random.randn(n, dim), C.T) + np.array([1, 4])]
    y = np.hstack((np.zeros(n), np.ones(n)))
    return X, y
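
# --- Hedged sanity check (my addition, not in the original paste): the class
# sample covariances should come out close for dataset_fixed_cov() and clearly
# different for dataset_cov(); np.cov over each class makes that visible.
for _name, _gen in (('fixed', dataset_fixed_cov), ('varying', dataset_cov)):
    _X, _y = _gen()
    print(_name, 'class covariances:')
    print(np.round(np.cov(_X[_y == 0].T), 2))
    print(np.round(np.cov(_X[_y == 1].T), 2))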
def plot_data(lda, X, y, y_pred, fig_index, plot_cov=True):
    splot = plt.subplot(2, 2, fig_index)
    if fig_index == 1:
        plt.title('Linear Discriminant Analysis')
        plt.ylabel('Data with\n fixed covariance')
    elif fig_index == 2:
        plt.title('Quadratic Discriminant Analysis')
    elif fig_index == 3:
        plt.ylabel('Data with\n varying covariances')
    tp = (y == y_pred)  # True for correctly classified samples
    colors_list = ['red', 'blue']
    dark_colors = ['#990000', '#000099']  # dark red / dark blue for errors
    for i in range(2):
        tp_i, X_i = tp[y == i], X[y == i]
        X_tp, X_fp = X_i[tp_i], X_i[~tp_i]
        # correct predictions as dots, misclassified samples as small crosses
        plt.scatter(X_tp[:, 0], X_tp[:, 1], marker='.', color=colors_list[i])
        plt.scatter(X_fp[:, 0], X_fp[:, 1], marker='x', s=20,
                    color=dark_colors[i])
    # decision surface: P(class 1) on a grid spanning the current axes
    nx, ny = 200, 100
    x_min, x_max = plt.xlim()
    y_min, y_max = plt.ylim()
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, nx),
                         np.linspace(y_min, y_max, ny))
    Z = lda.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1].reshape(xx.shape)
    plt.pcolormesh(xx, yy, Z, cmap='red_blue_classes',
                   norm=colors.Normalize(0., 1.), zorder=0)
    plt.contour(xx, yy, Z, [0.5], linewidths=2., colors='white')

    # class means
    plt.plot(lda.means_[:, 0], lda.means_[:, 1], '*', color='yellow',
             markersize=15, markeredgecolor='grey')
    if plot_cov:
        # LDA stores a single pooled covariance_; QDA stores one per class.
        if isinstance(lda, QuadraticDiscriminantAnalysis):
            covariances = lda.covariance_
        else:
            covariances = [lda.covariance_] * 2
        for i in range(2):
            plot_ellipse(splot, lda.means_[i], covariances[i], colors_list[i])

    splot.set_xticks(())
    splot.set_yticks(())


def plot_ellipse(splot, mean, cov, color):
    """Draw a filled ellipse at 2 standard deviations of a Gaussian."""
    v, w = linalg.eigh(cov)
    u = w[0] / linalg.norm(w[0])
    angle = 180. * np.arctan(u[1] / u[0]) / np.pi  # orientation in degrees
    ell = mpl.patches.Ellipse(mean, 2 * v[0] ** 0.5, 2 * v[1] ** 0.5,
                              angle=180. + angle, facecolor=color,
                              edgecolor='black', linewidth=2)
    ell.set_clip_box(splot.bbox)
    ell.set_alpha(0.5)
    splot.add_artist(ell)
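
# --- Hedged helper sketch (my addition, not in the original paste): for a
# fitted binary LDA, the white p = 0.5 contour drawn above is exactly the
# line coef_ . x + intercept_ = 0; this recovers its slope and intercept.
def lda_boundary_line(lda_model):
    """Return (slope, intercept) of the LDA decision line, assuming the
    second weight is nonzero."""
    w = lda_model.coef_[0]
    b = lda_model.intercept_[0]
    return -w[0] / w[1], -b / w[1]  # boundary: y = slope * x + intercept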
plt.figure(figsize=(10, 8), facecolor='white')
plt.suptitle('Linear Discriminant Analysis vs Quadratic Discriminant Analysis',
             y=0.98, fontsize=15)
for i, (X, y) in enumerate([dataset_fixed_cov(), dataset_cov()]):
    # left column: LDA with pooled covariance; right column: QDA
    lda = LinearDiscriminantAnalysis(solver='svd', store_covariance=True)
    qda = QuadraticDiscriminantAnalysis(store_covariance=True)
    plot_data(lda.fit(X, y), X, y, lda.predict(X), fig_index=2 * i + 1)
    plot_data(qda.fit(X, y), X, y, qda.predict(X), fig_index=2 * i + 2)
plt.tight_layout()
plt.subplots_adjust(top=0.92)
plt.show()
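
# --- Hedged follow-up sketch (my addition, not in the original paste): a
# numeric companion to the plots. On the shared-covariance data LDA and QDA
# should score about the same; on the varying-covariance data QDA should win.
for name, (X, y) in [('fixed covariance', dataset_fixed_cov()),
                     ('varying covariances', dataset_cov())]:
    lda = LinearDiscriminantAnalysis(solver='svd').fit(X, y)
    qda = QuadraticDiscriminantAnalysis().fit(X, y)
    print('%s: LDA accuracy %.3f, QDA accuracy %.3f'
          % (name, lda.score(X, y), qda.score(X, y)))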