• API
• FAQ
• Tools
• Archive
SHARE
TWEET # Untitled garchangel  Aug 25th, 2019 57 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
def generate_data(n_samples, n_features):
    """Generate a random blob classification problem.

    Only the first feature is discriminative: it separates two Gaussian
    blobs centered at -2 and +2.  The remaining ``n_features - 1``
    columns are pure standard-normal noise.

    Parameters
    ----------
    n_samples : int
        Number of samples to draw.
    n_features : int
        Total number of features (>= 1).

    Returns
    -------
    X : ndarray of shape (n_samples, n_features)
        Feature matrix.
    y : ndarray of shape (n_samples,)
        Binary class labels.
    """
    # Bug fix: the original passed a single center (``centers=[[-2], ]``),
    # which makes every label identical and causes
    # LinearDiscriminantAnalysis.fit to raise ("number of classes has to
    # be greater than one").  Two centers restore a 2-class problem.
    X, y = make_blobs(n_samples=n_samples, n_features=1,
                      centers=[[-2], [2]])
    if n_features > 1:
        # Pad with non-informative noise features so dimensionality grows
        # while the amount of real signal stays fixed.
        X = np.hstack([X, np.random.randn(n_samples, n_features - 1)])
    return X, y
# Sweep the number of features and record, for each dimensionality, the
# mean test accuracy of shrinkage LDA (clf1) and plain LDA (clf2),
# averaged over n_averages independent random draws.
acc_clf1, acc_clf2 = [], []
n_features_range = range(1, n_features_max + 1, step)
for n_features in n_features_range:
    score_clf1, score_clf2 = 0, 0
    for _ in range(n_averages):
        # Draw a fresh training sample and fit both variants on it.
        X, y = generate_data(n_train, n_features)
        clf1 = LinearDiscriminantAnalysis(solver='lsqr',
                                          shrinkage='auto').fit(X, y)
        clf2 = LinearDiscriminantAnalysis(solver='lsqr',
                                          shrinkage=None).fit(X, y)
        # Evaluate on an independent test sample of the same dimension.
        X, y = generate_data(n_test, n_features)
        score_clf1 = score_clf1 + clf1.score(X, y)
        score_clf2 = score_clf2 + clf2.score(X, y)
    # Store the per-dimensionality average accuracy of each model.
    acc_clf1.append(score_clf1 / n_averages)
    acc_clf2.append(score_clf2 / n_averages)
# Plot mean accuracy of both models against the ratio of dimensionality
# to training-set size (the regime where shrinkage is expected to help).
features_samples_ratio = np.array(n_features_range) / n_train
for accuracies, curve_label, curve_color in (
        (acc_clf1, "Linear Discriminant Analysis with shrinkage", 'navy'),
        (acc_clf2, "Linear Discriminant Analysis", 'gold')):
    plt.plot(features_samples_ratio, accuracies, linewidth=2,
             label=curve_label, color=curve_color)
plt.xlabel('n_features / n_samples')
plt.ylabel('Classification accuracy')
plt.legend(loc=1, prop={'size': 12})
plt.suptitle('Linear Discriminant Analysis vs. shrinkage Linear '
             'Discriminant Analysis (1 discriminative feature)')
plt.show()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.

Top