SHARE
TWEET

Untitled

a guest Dec 3rd, 2019 87 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Tue Dec  3 15:48:32 2019
  4.  
  5. @author: Eric
  6. """
  7. import numpy as np
  8. import seaborn as sns
  9. import matplotlib
  10. import matplotlib.pyplot as plt
  11.  
  12.  
  13. sns.set_style("whitegrid")
  14.  
  15. blue, = sns.color_palette("muted", 1)
  16.  
  17. def plot_mnist(elts, m, n, name = -99):
  18.     """Plot MNIST images in an m by n table. Note that we crop the images
  19.    so that they appear reasonably close together.  Note that we are
  20.    passed raw MNIST data and it is reshaped.
  21.  
  22.    Example: plot_mnist(X_train, 10, 10)
  23.    """
  24.     fig = plt.figure()
  25.     images = [elt.reshape(28, 28) for elt in elts]
  26.     img = np.concatenate([np.concatenate([images[m*y+x] for x in range(m)], axis=1)
  27.                           for y in range(n)], axis=0)
  28.     ax = fig.add_subplot(1, 1, 1)
  29.     ax.matshow(img, cmap = matplotlib.cm.binary)
  30.     for spine in plt.gca().spines.values():
  31.         spine.set_visible(False)
  32.     plt.xticks(np.array([]))
  33.     plt.yticks(np.array([]))
  34.     plt.show()
  35.     if (name > -2):
  36.         title = "fig_"
  37.         name = str(name)
  38.         title = title + name + ".png"
  39.         plt.savefig(title)
  40.  
  41.  
  42. def PCA(X):
  43.     n,m = X.shape
  44.     #assert np.allclose(X.mean(axis=0),np.zeros(m))
  45.    
  46.     #Data is centered, Get Covariance
  47.     coVar = np.dot(X.T, X) / (n -1)
  48.     #coVar = np.cov(X)
  49.     #get eigens
  50.     eVal, eVec = np.linalg.eig(coVar)
  51.    
  52.     xPca = np.dot(X, eVec)
  53.    
  54.     return xPca, eVal, eVec
  55.  
  56.  
  57. def main():
  58.     X_train = np.load('../Data/X_train.npy')
  59.     X_val = np.load('../Data/X_val.npy')
  60.    
  61.     Y_train = np.load('../Data/Y_train.npy')
  62.  
  63.     #X_train = X_train.astype(np.float64)
  64.     m = X_train.mean(axis = 0)
  65.    
  66.     X = X_train - m
  67.     #print(X_train.shape)
  68.     #print(Y_train.shape)
  69.     """
  70.    meanImage = m.reshape(28,28)
  71.    plt.axis('off')
  72.    plt.imshow(meanImage,cmap = matplotlib.cm.binary)
  73.    """
  74.     avgImages = []
  75.    
  76.     """
  77.    for i in range(10):
  78.        meanDig = X_train[Y_train == i].mean(axis=0)
  79.        print (meanDig.shape)
  80.        meanDig = meanDig.reshape(28,28)
  81.        avgImages.append(meanDig)
  82.    
  83.    plot_mnist(avgImages,5,2)
  84.    """        
  85.        
  86.     xPca, eVal, eVec = PCA(X)
  87.    
  88.     for i in range(5):
  89.         eigImg = eVec[i].reshape(28,28)
  90.         avgImages.append(eigImg)
  91.        
  92.     plot_mnist(avgImages,5,1)
  93.     """
  94.    print(eVal.shape)
  95.    numPC = np.arange(eVal.shape[0])[:25]
  96.    largest = eVal[:25]
  97.    fig, ax = plt.subplots()
  98.    ax.plot(numPC, largest, color = blue, lw = 3)
  99.    ax.set_ylabel("Eigen Values")
  100.    ax.set_xlabel("PCA Number")
  101.    ax.fill_between(numPC,0,largest, alpha = .25)
  102.    ax.set(xlim=(0, len(numPC) - 1), ylim=(0, None))
  103.    """
  104.    
  105.    
  106.    
  107.  
  108. if __name__ == '__main__':
  109.     main()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top