• API
• FAQ
• Tools
• Archive
SHARE
TWEET

Untitled

a guest Dec 3rd, 2019 87 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. # -*- coding: utf-8 -*-
2. """
3. Created on Tue Dec  3 15:48:32 2019
4.
5. @author: Eric
6. """
7. import numpy as np
8. import seaborn as sns
9. import matplotlib
10. import matplotlib.pyplot as plt
11.
12.
# Apply seaborn's white-grid look globally to every figure drawn afterwards.
sns.set_style(style="whitegrid")
# Primary line colour: the single colour taken from seaborn's "muted" palette.
(blue,) = sns.color_palette("muted", n_colors=1)
16.
def plot_mnist(elts, m, n, name = -99):
    """Plot MNIST images tiled edge to edge into a grid of m columns by n rows.

    Each element of ``elts`` is reshaped to 28x28 and the tiles are
    concatenated with no padding between them.  Tiles are laid out
    row-major: ``elts[0:m]`` form the top row, ``elts[m:2*m]`` the next,
    and so on.  Only the first ``m * n`` elements are used.

    Example: plot_mnist(X_train, 10, 10)

    Parameters
    ----------
    elts : sequence of array-like
        Raw (flattened or 28x28) MNIST images; each must hold 784 values.
    m : int
        Number of columns (images per row).
    n : int
        Number of rows.
    name : int, float or str, optional
        Suffix for the saved file ``fig_<name>.png``.  A numeric ``name`` is
        saved only when it is greater than -2, so the default ``-99`` means
        "do not save".  A non-numeric ``name`` (e.g. a string) is always
        saved; ``None`` disables saving.

    Raises
    ------
    IndexError
        If ``elts`` holds fewer than ``m * n`` images.
    """
    import numbers

    fig = plt.figure()
    images = [np.reshape(elt, (28, 28)) for elt in elts]
    # One horizontal strip of m images per row, then stack the n strips.
    rows = [np.concatenate([images[m * y + x] for x in range(m)], axis=1)
            for y in range(n)]
    img = np.concatenate(rows, axis=0)

    ax = fig.add_subplot(1, 1, 1)
    ax.matshow(img, cmap=matplotlib.cm.binary)
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])

    # Numeric names keep the original "> -2 means save" sentinel rule;
    # comparing a str with an int would raise TypeError, so other types are
    # saved unless they are None.
    if isinstance(name, numbers.Real):
        save = name > -2
    else:
        save = name is not None
    if save:
        # Save BEFORE show(): in blocking mode show() only returns after the
        # window is closed and the figure has been released, so saving
        # afterwards writes an empty image.  Use fig.savefig so the right
        # figure is saved regardless of which figure is "current".
        fig.savefig("fig_" + str(name) + ".png")
    plt.show()
40.
41.
def PCA(X):
    """Principal component analysis of an already-centred data matrix.

    The sample covariance ``X.T @ X / (n - 1)`` is diagonalised with
    ``np.linalg.eigh``, which is meant for symmetric matrices such as a
    covariance.  It always yields real eigenvalues, unlike the general
    ``eig`` solver, which can return tiny spurious imaginary parts.
    Components are returned in order of decreasing eigenvalue, so
    ``eVal[:k]`` / ``eVec[:, :k]`` are the k leading components.

    Parameters
    ----------
    X : array-like, shape (n, m)
        ``n`` samples of ``m`` features.  The caller must subtract the
        column means first; this function does not centre the data.

    Returns
    -------
    xPca : ndarray, shape (n, m)
        ``X`` projected onto the principal axes (column ``j`` holds the
        scores on component ``j``).
    eVal : ndarray, shape (m,)
        Eigenvalues of the covariance (variance along each component),
        sorted in descending order.
    eVec : ndarray, shape (m, m)
        Unit eigenvectors stored as COLUMNS: ``eVec[:, j]`` pairs with
        ``eVal[j]``.  The sign of each eigenvector is arbitrary.

    Raises
    ------
    ValueError
        If ``X`` is not 2-D, or has fewer than two rows (the ``n - 1``
        normalisation would divide by zero).
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError("PCA expects a 2-D array, got shape %r" % (X.shape,))
    n = X.shape[0]
    if n < 2:
        raise ValueError("PCA needs at least two samples, got %d" % n)

    # Data is assumed centred, so this is the sample covariance matrix.
    cov = np.dot(X.T, X) / (n - 1)

    # eigh returns eigenvalues in ascending order; reverse both outputs so
    # that component 0 carries the largest variance.
    e_val, e_vec = np.linalg.eigh(cov)
    e_val = e_val[::-1]
    e_vec = e_vec[:, ::-1]

    x_pca = np.dot(X, e_vec)
    return x_pca, e_val, e_vec
55.
56.
def _plot_mean_image(mean):
    """Show one flattened 28x28 image (e.g. the training-set mean) in its own figure."""
    plt.figure()
    plt.axis('off')
    plt.imshow(np.reshape(mean, (28, 28)), cmap=matplotlib.cm.binary)
    plt.show()


def _plot_digit_means(X_train, Y_train):
    """Plot the mean training image of each digit 0-9 as 5 columns x 2 rows."""
    digit_means = [X_train[Y_train == digit].mean(axis=0) for digit in range(10)]
    plot_mnist(digit_means, 5, 2)


def _plot_scree(eVal, k=25):
    """Plot the first ``k`` entries of ``eVal`` (expected sorted descending)."""
    largest = eVal[:k]
    numPC = np.arange(len(largest))
    fig, ax = plt.subplots()
    ax.plot(numPC, largest, color=blue, lw=3)
    ax.set_ylabel("Eigen Values")
    ax.set_xlabel("PCA Number")
    ax.fill_between(numPC, 0, largest, alpha=.25)
    ax.set(xlim=(0, len(numPC) - 1), ylim=(0, None))
    plt.show()


def main(*, data_dir='../Data', show_mean=False, show_digit_means=False,
         show_scree=False):
    """Run PCA on the MNIST training images and plot the five leading eigen-images.

    Loads ``X_train.npy`` from ``data_dir`` and centres it by subtracting the
    per-pixel mean.  It then runs ``PCA`` and shows the five components with
    the largest eigenvalues as 28x28 images in a single row.  The keyword
    flags switch on additional exploratory plots; all are off by default.

    Parameters
    ----------
    data_dir : str
        Directory holding the ``.npy`` files.  A relative path is resolved
        against the current working directory.
    show_mean : bool
        Also show the mean training image.
    show_digit_means : bool
        Also plot the mean image of each digit class (loads ``Y_train.npy``).
    show_scree : bool
        Also plot the 25 largest eigenvalues.
    """
    import os

    X_train = np.load(os.path.join(data_dir, 'X_train.npy'))
    mean = X_train.mean(axis=0)
    X = X_train - mean  # PCA() expects column-centred data

    if show_mean:
        _plot_mean_image(mean)
    if show_digit_means:
        Y_train = np.load(os.path.join(data_dir, 'Y_train.npy'))
        _plot_digit_means(X_train, Y_train)

    _, eVal, eVec = PCA(X)
    # Order components by decreasing eigenvalue here instead of relying on
    # PCA() to do so.  np.real drops round-off imaginary parts that a general
    # (non-symmetric) eigensolver may produce.  The stable sort leaves
    # already-sorted input untouched.
    order = np.argsort(-np.real(eVal), kind="stable")
    eVal = np.real(eVal)[order]
    eVec = np.real(eVec)[:, order]

    # Eigenvectors are the COLUMNS of eVec.  A row eVec[i] would instead be
    # pixel i's coordinate in every eigenvector, not an eigen-image.
    eig_images = [eVec[:, i] for i in range(5)]
    plot_mnist(eig_images, 5, 1)

    if show_scree:
        _plot_scree(eVal)


if __name__ == '__main__':
    main()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.

Top