Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
# Importing Axes3D also registers the "3d" projection on older Matplotlib releases.
from mpl_toolkits.mplot3d import Axes3D
from sklearn.manifold import TSNE
def tsne_scatter_3d(look_up_list, title, group_size=10, figsize=None):
    """Project embeddings to 3-D with t-SNE and show them as a scatter plot.

    The embedding of each record is read from index 2 of that record. After
    the t-SNE projection, the points are split (in input order) into
    consecutive groups of ``group_size`` and each group is drawn in its own
    color taken from the rainbow colormap.

    Args:
        look_up_list: Sequence of indexable records whose element ``[2]``
            supports ``.detach()`` and conversion to a NumPy array
            (presumably ``torch.Tensor`` embeddings -- confirm against the
            caller).
        title: Title shown above the plot.
        group_size: Number of consecutive points that share one color.
            Defaults to 10, the value that used to be hard-coded.
        figsize: Optional ``(width, height)`` in inches forwarded to
            ``plt.figure``; ``None`` keeps Matplotlib's default size.

    Raises:
        ValueError: If ``group_size`` is smaller than 1.
    """
    if group_size < 1:
        raise ValueError("group_size must be a positive integer")

    # .cpu() is a no-op for tensors already on the CPU and prevents a failure
    # when the embeddings live on another device.
    embeddings = np.stack(
        [record[2].detach().cpu().numpy() for record in look_up_list]
    )

    # Fixed random_state keeps the layout reproducible between calls.
    projected = TSNE(n_components=3, random_state=0).fit_transform(embeddings)

    # One color per group of `group_size` consecutive points.
    group_starts = range(0, len(projected), group_size)
    colors = cm.rainbow(np.linspace(0, 1, len(group_starts)))

    fig = plt.figure(figsize=figsize)
    # Axes3D(fig) no longer attaches itself to the figure on newer Matplotlib
    # releases (resulting in an empty plot); add_subplot always registers it.
    ax = fig.add_subplot(111, projection="3d")
    for start, color in zip(group_starts, colors):
        group = projected[start:start + group_size]
        ax.scatter(group[:, 0], group[:, 1], group[:, 2], color=color)

    # Set the title on the 3-D axes explicitly instead of on "current" axes.
    ax.set_title(title)
    plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement