Advertisement
Guest User

Untitled

a guest
Apr 26th, 2017
90
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.71 KB | None | 0 0
  1. from sklearn.manifold import TSNE
  2. from mpl_toolkits.mplot3d import Axes3D
  3. %matplotlib qt
  4. from IPython import display
  5. import matplotlib.cm as cmx
  6. import matplotlib.colors as colors
  7.  
  8. def get_cmap(N):
  9. '''Returns a function that maps each index in 0, 1, ... N-1 to a distinct
  10. RGB color.'''
  11. color_norm = colors.Normalize(vmin=0, vmax=N-1)
  12. scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')
  13. def map_index_to_rgb_color(index):
  14. return scalar_map.to_rgba(index)
  15. return map_index_to_rgb_color
  16.  
  17. def plot_latent_space(x_batch, y_batch, iteration=None, dim=2):
  18.  
  19. model = TSNE(n_components=dim, random_state=0, perplexity=50, learning_rate=500, n_iter=200)
  20. z_mu = model.fit_transform(mu.eval(feed_dict={X: x_batch}))
  21. n_classes = len(list(set(np.argmax(y_batch, 1))))
  22. cmap = get_cmap(n_classes)
  23. fig = plt.figure(2, figsize=(8,8))
  24.  
  25. if dim is 3:
  26. for i in list(set(np.argmax(y_batch, 1))):
  27. bx = fig.add_subplot(111, projection='3d')
  28.  
  29. index = np.where(np.argmax(y_batch, 1) == i)
  30. xs = z_mu[index, 0]
  31. ys = z_mu[index, 1:]
  32. zs = z_mu[index, 2]
  33. bx.scatter(xs, ys, zs,c=cmap(i), label=str(i))
  34. else:
  35. for i in list(set(np.argmax(y_batch, 1))):
  36. bx = fig.add_subplot(111)
  37. index = np.where(np.argmax(y_batch, 1) == i)
  38. xs = z_mu[index, 0]
  39. ys = z_mu[index, 1]
  40. bx.scatter(xs, ys, c=cmap(i), label=str(i))
  41.  
  42. bx.set_xlabel('X Label')
  43. bx.set_ylabel('Y Label')
  44. bx.legend()
  45. bx.set_title('Truth')
  46. if iteration is None:
  47. plt.savefig('latent_space.png')
  48. else:
  49. plt.savefig('latent_space' + str(iteration) + '.png')
  50. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement