Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- %matplotlib inline
- import matplotlib.pyplot as plt
- from matplotlib.gridspec import GridSpec
- from matplotlib.patches import Ellipse
- import matplotlib.patches as mpatches
- from scipy.stats import norm
- import numpy as np
# matplotlib config ################################################
# Route all figure text through LaTeX (usetex needs a LaTeX installation),
# loading the Times font package, with a 14pt serif base font.
plt.rcParams.update({
    "text.usetex": True,
    "text.latex.preamble": r'\usepackage{times}',
    "font.family": "serif",
    "font.size": 14,
})
# Close any figures still open from an earlier run of this script/cell.
plt.close("all")
- # helpers #########################################################
def values_1d(mu, sigma, num=100, n_sigma=5):
    """Sample a 1-D normal density on an evenly spaced grid.

    Args:
        mu: Mean of the normal distribution.
        sigma: Standard deviation; must be positive.
        num: Number of grid points (default 100).
        n_sigma: Half-width of the grid in units of ``sigma``; the grid
            spans ``[mu - n_sigma * sigma, mu + n_sigma * sigma]``
            (default 5).

    Returns:
        Tuple ``(grid, density)`` of 1-D arrays of length ``num``, suitable
        for ``plt.plot(*values_1d(mu, sigma))``.

    Raises:
        ValueError: If ``sigma`` is not positive.
    """
    if sigma <= 0:
        # scipy.stats.norm silently yields NaNs for a non-positive scale.
        raise ValueError(f"sigma must be positive, got {sigma!r}")
    half_width = n_sigma * sigma
    grid = np.linspace(mu - half_width, mu + half_width, num)
    return grid, norm.pdf(grid, mu, sigma)
def plot_ellipse(ax, mu, cov, nstds=(1, 2, 3), color="C0"):
    """Draw concentric n-sigma contours of a 2-D Gaussian as filled ellipses.

    Args:
        ax: Matplotlib Axes to draw on.
        mu: Centre ``(x, y)`` of the ellipses.
        cov: 2x2 covariance matrix. ``np.linalg.eigh`` reads only its lower
            triangle, so pass a symmetric matrix.
        nstds: Contour levels in standard deviations; one translucent
            ellipse is drawn per entry.
        color: Matplotlib colour used for every ellipse.

    Returns:
        ``(mu, width, height)`` where ``width``/``height`` are the full
        diameters of the largest contour along its major/minor axes.

    Raises:
        ValueError: If ``nstds`` is empty.
    """
    nstds = list(nstds)
    if not nstds:
        raise ValueError("nstds must contain at least one contour level")

    # eigh returns eigenvalues in ascending order; flip so the major axis
    # (largest variance) comes first and maps onto the ellipse width.
    vals, vecs = np.linalg.eigh(cov)
    order = vals.argsort()[::-1]
    vals, vecs = vals[order], vecs[:, order]
    major_x, major_y = vecs[:, 0]
    theta = np.degrees(np.arctan2(major_y, major_x))

    # Ellipse width/height are full diameters, so an n-sigma contour spans
    # 2 * n * sqrt(variance) along each principal axis.
    std_devs = np.sqrt(vals)
    for nstd in nstds:
        w, h = 2 * nstd * std_devs
        ax.add_patch(Ellipse(xy=mu,
                             width=w,
                             height=h,
                             angle=theta,
                             color=color,
                             alpha=0.2))

    w, h = 2 * max(nstds) * std_devs
    return mu, w, h
# 16x4-inch figure; the four panels are added to it with plt.subplot(1, 4, k).
f = plt.figure(figsize=(16, 4))
# data to plot #####################################################
# Fixed seed so the synthetic embedding scatters are identical on every run.
rng = np.random.default_rng(0)

# 1-D Gaussians: mean and standard deviation.
mu_1d_real = 0
sigma_1d_real = 1
mu_1d_fake = 1.1
sigma_1d_fake = 1.1

# 2-D Gaussians: mean vector and covariance matrix.
mu_2d_real = [0, 0]
mu_2d_fake = [0, 0]
sigma_2d_real = [[1, 0], [0, 1]]
# A covariance matrix must be symmetric; np.linalg.eigh reads only the
# lower triangle, so the off-diagonal 0.4 must appear in both positions.
sigma_2d_fake = [[1.2, 0.4], [0.4, 1]]


def _random_embedding(n_points=100):
    """Return an ``(x, y)`` pair of ``(n_points, 1)`` standard-normal columns."""
    return rng.standard_normal((n_points, 1)), rng.standard_normal((n_points, 1))


# Synthetic 2-D point clouds, one per conditional shown in panels c) and d).
emb_2_1 = _random_embedding()
emb_2_3 = _random_embedding()
emb_2_13 = _random_embedding()
emb_1_23 = _random_embedding()
emb_1_2 = _random_embedding()
emb_1_3 = _random_embedding()
# 1d plot ##########################################################
# Panel a): real vs. generated 1-D densities, unlabeled inward ticks.
ax1 = plt.subplot(1, 4, 1)
ax1.tick_params(direction="in")
ax1.set_xticklabels([])
ax1.set_yticklabels([])
ax1.set_title("a) $r=(1, 0, 0)$, $a=(0, 1, 1)$")
real_curve = values_1d(mu_1d_real, sigma_1d_real)
fake_curve = values_1d(mu_1d_fake, sigma_1d_fake)
ax1.plot(*real_curve, lw=3, label="real")
ax1.plot(*fake_curve, lw=3, label="generated")
ax1.legend(loc="upper right")
# 2d plot ##########################################################
# Panel b): 1/2/3-sigma contours of the real and generated 2-D Gaussians.
ax2 = plt.subplot(1, 4, 2)
ax2.set_yticklabels([])
ax2.set_xticklabels([])
ax2.tick_params(direction="in")
ax2.set_title("b) $r=(1, 1, 0)$, $a=(0, 0, 1)$")
c1, w1, h1 = plot_ellipse(ax2, mu=mu_2d_real, cov=sigma_2d_real, color="C0")
c2, w2, h2 = plot_ellipse(ax2, mu=mu_2d_fake, cov=sigma_2d_fake, color="C1")
# The ellipses carry no labels, so build legend entries from proxy patches.
# "generated" matches the legend wording used in panel a).
real_patch = mpatches.Patch(color="C0", alpha=0.2, label="real")
fake_patch = mpatches.Patch(color="C1", alpha=0.2, label="generated")
ax2.legend(handles=[real_patch, fake_patch], loc="upper right")

# Rough bounding box of both outermost ellipses, zoomed to 70% about its
# centre. Scaling the raw limits directly (lo * 0.7) only zooms correctly
# when the means sit at the origin.
zoom = 0.7
x_lo = min(mu_2d_real[0] - w1, mu_2d_fake[0] - w2)
x_hi = max(mu_2d_real[0] + w1, mu_2d_fake[0] + w2)
y_lo = min(mu_2d_real[1] - h1, mu_2d_fake[1] - h2)
y_hi = max(mu_2d_real[1] + h1, mu_2d_fake[1] + h2)
x_mid = (x_lo + x_hi) / 2
y_mid = (y_lo + y_hi) / 2
xmin = x_mid + zoom * (x_lo - x_mid)
xmax = x_mid + zoom * (x_hi - x_mid)
ymin = y_mid + zoom * (y_lo - y_mid)
ymax = y_mid + zoom * (y_hi - y_mid)
ax2.set_xlim(xmin, xmax)
ax2.set_ylim(ymin, ymax)
# embeddings r=2 ###################################################
# Panel c): scatter of the three r=2 conditional embeddings.
ax3 = plt.subplot(1, 4, 3)
ax3.set_yticklabels([])
ax3.set_xticklabels([])
ax3.tick_params(direction="in")
ax3.set_title("c) $r=(0,1,0)$")
for points, tex_label in ((emb_2_1, "$2|1$"),
                          (emb_2_3, "$2|3$"),
                          (emb_2_13, "$2|1,3$")):
    ax3.plot(*points, '.', label=tex_label)
ax3.legend(loc="lower right")
# embeddings r=1 ###################################################
# Panel d): scatter of the three r=1 conditional embeddings.
ax4 = plt.subplot(1, 4, 4)
ax4.set_yticklabels([])
ax4.set_xticklabels([])
# Inward ticks, like the other three panels of the figure.
ax4.tick_params(direction="in")
ax4.set_title("d) $r=(1,0,0)$")
ax4.plot(*emb_1_23, '.', label="$1|2,3$")
ax4.plot(*emb_1_2, '.', label="$1|2$")
ax4.plot(*emb_1_3, '.', label="$1|3$")
ax4.legend(loc=4)
# finish up ########################################################
# tight_layout's padding arguments are keyword-only in recent Matplotlib
# releases (positional calls raise TypeError), so pass them by name.
plt.tight_layout(pad=1, h_pad=1, w_pad=1)
plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement