Guest User

Untitled

a guest
Sep 20th, 2018
70
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.29 KB | None | 0 0
  1. from os.path import join, basename, splitext
  2. import argparse
  3.  
  4. import numpy as np
  5. from sklearn.decomposition import PCA
  6. import matplotlib.pyplot as plt
  7.  
  8. # setup parser
  9. parser = argparse.ArgumentParser()
  10. parser.add_argument("X", help="filename of the feature file (`.npy`) to visualize")
  11. parser.add_argument("y", help="filename of the label file (`.csv` or `.npy`) to visualize classes")
  12. parser.add_argument("out_fn", help="filename for the outputing image (`.pdf`)")
  13. args = parser.parse_args()
  14.  
  15. # load the feature file
  16. X = np.load(args.X)
  17.  
  18. # load the label file
  19. ext = splitext(args.y)[1]
  20. if ext == '.csv':
  21. with open(args.y) as f:
  22. y = np.array([l.split('\n')[0] for l in f])
  23. elif ext == '.npy':
  24. y = np.load(args.y)
  25. else:
  26. raise NotImplementedError('{} is not supported!'.format(ext))
  27.  
  28. # check shape
  29. if X.shape[0] != len(y):
  30. raise ValueError('Feature & label should have same number of samples!')
  31.  
  32. # run the PCA
  33. pca = PCA(2)
  34. z = pca.fit_transform(X)
  35.  
  36. # markers
  37. markers = ['o', '.', ',', 'x', '+', 'v', '^', '<', '>', 's', 'd']
  38.  
  39. # visualize per label
  40. for k, label in enumerate(set(y)):
  41. idx = np.where(y == label)[0]
  42. plt.scatter(z[idx, 0], z[idx, 1], label=label,
  43. marker=markers[(len(markers) % (k + 1)) - 1])
  44.  
  45. # save fig
  46. plt.legend()
  47. plt.tight_layout()
  48. plt.savefig(args.out_fn)
Add Comment
Please sign in to add a comment.