Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from sklearn.manifold import TSNE
- from mpl_toolkits.mplot3d import Axes3D
- %matplotlib qt
- from IPython import display
- import matplotlib.cm as cmx
- import matplotlib.colors as colors
def get_cmap(N):
    """Return a function that maps each index in 0, 1, ..., N-1 to a distinct color.

    The returned callable takes an index and returns the value produced by
    ``ScalarMappable.to_rgba`` for it (an RGBA tuple for a scalar index),
    sampled from Matplotlib's 'hsv' colormap.

    Args:
        N: Number of distinct colors required.

    Returns:
        A function ``f(index)`` giving the color assigned to ``index``.
    """
    # 'hsv' is cyclic: both ends of the map are red. Normalizing over
    # [0, N] instead of [0, N-1] keeps index N-1 strictly below the wrap
    # point, so the first and last indices no longer share a color.
    color_norm = colors.Normalize(vmin=0, vmax=N)
    scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')

    def map_index_to_rgb_color(index):
        """Return the color assigned to ``index``."""
        return scalar_map.to_rgba(index)

    return map_index_to_rgb_color
def plot_latent_space(x_batch, y_batch, iteration=None, dim=2):
    """Project the encoder output for a batch with t-SNE and scatter-plot it by class.

    The batch is fed through the module-level ``mu`` (evaluated with
    ``feed_dict={X: x_batch}``), the result is embedded into ``dim``
    dimensions with t-SNE, and one scatter series is drawn per class.
    The figure is saved to ``latent_space.png`` (or
    ``latent_space<iteration>.png``) and then shown.

    Args:
        x_batch: Input batch fed to the placeholder ``X``.
        y_batch: Per-sample class scores; the argmax along axis 1 is used as
            the class label (presumably one-hot labels -- confirm with caller).
        iteration: Optional value appended to the output file name.
        dim: Embedding dimensionality; 3 gives a 3D plot, anything else is
            drawn as a 2D plot of the first two components.
    """
    # NOTE(review): recent scikit-learn releases rename `n_iter` to
    # `max_iter` -- confirm against the installed version.
    model = TSNE(n_components=dim, random_state=0, perplexity=50,
                 learning_rate=500, n_iter=200)
    z_mu = model.fit_transform(mu.eval(feed_dict={X: x_batch}))

    labels = np.argmax(y_batch, 1)
    classes = np.unique(labels)  # sorted, so colors/legend order are stable
    cmap = get_cmap(len(classes))

    fig = plt.figure(2, figsize=(8, 8))
    # Figure number 2 is reused on every call; clear it so repeated calls
    # (e.g. once per training iteration) do not draw on top of old plots.
    fig.clf()

    is_3d = dim == 3  # value comparison; `is` on an int literal is unreliable
    # Create the axes once rather than once per class.
    if is_3d:
        bx = fig.add_subplot(111, projection='3d')
    else:
        bx = fig.add_subplot(111)
    n_axes = 3 if is_3d else 2

    # Color by position in `classes`, not by raw label value, so the lookup
    # stays inside the 0..n-1 range get_cmap was built for even when some
    # labels are missing from this batch.
    for color_idx, cls in enumerate(classes):
        mask = labels == cls
        coords = [z_mu[mask, axis] for axis in range(n_axes)]
        # A one-row color list is unambiguous even when a class happens to
        # contain exactly 3 or 4 points.
        bx.scatter(*coords, c=[cmap(color_idx)], label=str(cls))

    bx.set_xlabel('X Label')
    bx.set_ylabel('Y Label')
    if is_3d:
        bx.set_zlabel('Z Label')
    bx.legend()
    bx.set_title('Truth')

    if iteration is None:
        plt.savefig('latent_space.png')
    else:
        plt.savefig('latent_space' + str(iteration) + '.png')
    plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement