Advertisement
Guest User

Untitled

a guest
Oct 16th, 2018
73
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.78 KB | None | 0 0
  1. import matplotlib.pyplot as plt
  2. import mnist
  3. import numpy as np
  4.  
  5. np.random.seed(1234)
  6.  
  7. plt.rcParams["figure.figsize"] = [16, 9]
  8.  
  9.  
  10. def tiles(examples):
  11. rows_count = examples.shape[0]
  12. cols_count = examples.shape[1]
  13. tile_height = examples.shape[2]
  14. tile_width = examples.shape[3]
  15.  
  16. space_between_tiles = 2
  17. img_matrix = np.empty(shape=( (tile_height + space_between_tiles) * rows_count,(tile_width + space_between_tiles) * cols_count ))
  18. img_matrix.fill(np.nan)
  19.  
  20. for row in range(rows_count):
  21. for col in range(cols_count):
  22. img_matrix[row*(tile_height + space_between_tiles):row*(tile_height + space_between_tiles)+tile_height, col*(tile_width + space_between_tiles):col*(tile_width + space_between_tiles)+tile_width] = examples[row, col]
  23.  
  24.  
  25. #raise Exception("Not implemented!")
  26.  
  27. return img_matrix
  28.  
  29. def plot_2d_mnist_scatter(X, y):
  30. fig, plot = plt.subplots()
  31. fig.set_size_inches(16, 16)
  32. plt.prism()
  33. y.shape
  34. for i in range(10):
  35. digit_indeces = X[i==y]
  36. dim1 = digit_indeces[:, 0]
  37. dim2 = digit_indeces[:, 1]
  38. plot.scatter(dim1, dim2, label=i)
  39.  
  40.  
  41. plot.set_xticks(())
  42. plot.set_yticks(())
  43.  
  44. plt.tight_layout()
  45. plt.legend()
  46. plt.show()
  47.  
  48.  
  49.  
  50. X = mnist.train_images().astype(np.float32) / 255.0
  51. y = mnist.train_labels()
  52.  
  53. print(X.shape)
  54. X = np.reshape(X,(X.shape[0],-1))
  55. X = X[:10000,:]
  56. y = y[:10000]
  57. print(X.shape)
  58.  
  59. from sklearn.decomposition import PCA
  60.  
  61. pca = PCA(n_components=20)
  62. X_pca = pca.fit(X).transform(X)
  63.  
  64. #print(X_pca.shape)
  65.  
  66. #plot_2d_mnist_scatter(X_pca,y)
  67.  
  68. from sklearn.manifold import TSNE
  69.  
  70. tsne = TSNE(n_iter=300)
  71. x_tsne = TSNE(n_components=2).fit_transform(X)
  72.  
  73.  
  74. print(x_tsne.shape)
  75. plot_2d_mnist_scatter(x_tsne,y)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement