Advertisement
Guest User

Untitled

a guest
Jan 22nd, 2019
76
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.92 KB | None | 0 0
  1. %matplotlib inline
  2.  
  3. import matplotlib.pyplot as plt
  4. from matplotlib.gridspec import GridSpec
  5. from matplotlib.patches import Ellipse
  6. import matplotlib.patches as mpatches
  7. from scipy.stats import norm
  8. import numpy as np
  9.  
  10. # matplotlib config ################################################
  11.  
  12. plt.rc('text', usetex=True)
  13. plt.rc('text.latex', preamble=r'\usepackage{times}')
  14. plt.rc('font', family='serif')
  15. plt.rc('font', size=14)
  16. plt.close("all")
  17.  
  18. # helpers #########################################################
  19.  
  20. def values_1d(mu, sigma):
  21.     g = np.linspace(mu - 5 * sigma, mu + 5 * sigma, 100)
  22.     v = norm.pdf(g, mu, sigma)
  23.     return g, v
  24.  
  25.  
  26. def plot_ellipse(ax, mu, cov, nstds=[1, 2, 3], color="C0"):
  27.     def eigsorted(cov):
  28.         vals, vecs = np.linalg.eigh(cov)
  29.         order = vals.argsort()[::-1]
  30.         return vals[order], vecs[:,order]
  31.  
  32.     vals, vecs = eigsorted(cov)
  33.     theta = np.degrees(np.arctan2(*vecs[:,0][::-1]))
  34.    
  35.     for nstd in nstds:
  36.         w, h = nstd * np.sqrt(vals)
  37.         ell = Ellipse(xy=mu,
  38.                       width=w,
  39.                       height=h,
  40.                       angle=theta,
  41.                       color=color,
  42.                       alpha=0.2)
  43.         ax.add_artist(ell)
  44.        
  45.     return mu, w, h
  46.  
  47.        
  48. f = plt.figure(figsize=(16, 4))
  49.  
  50. # data to plot #####################################################
  51.  
  52. mu_1d_real = 0
  53. sigma_1d_real = 1
  54.  
  55. mu_1d_fake = 1.1
  56. sigma_1d_fake = 1.1
  57.  
  58. mu_2d_real = [0, 0]
  59. mu_2d_fake = [0, 0]
  60.  
  61. sigma_2d_real = [[1, 0], [0, 1]]
  62. sigma_2d_fake = [[1.2, 0], [0.4, 1]]
  63.  
  64. emb_2_1 = np.random.randn(100, 1), np.random.randn(100, 1)
  65. emb_2_3 = np.random.randn(100, 1), np.random.randn(100, 1)
  66. emb_2_13 = np.random.randn(100, 1), np.random.randn(100, 1)
  67.  
  68. emb_1_23 = np.random.randn(100, 1), np.random.randn(100, 1)
  69. emb_1_2 = np.random.randn(100, 1), np.random.randn(100, 1)
  70. emb_1_3 = np.random.randn(100, 1), np.random.randn(100, 1)
  71.  
  72. # 1d plot ##########################################################
  73.  
  74. ax1 = plt.subplot(1, 4, 1)
  75. ax1.set_yticklabels([])
  76. ax1.set_xticklabels([])
  77. ax1.tick_params(direction="in")
  78. plt.title("a) $r=(1, 0, 0)$, $a=(0, 1, 1)$")
  79.  
  80. plt.plot(*values_1d(mu_1d_real, sigma_1d_real), lw=3, label="real")
  81. plt.plot(*values_1d(mu_1d_fake, sigma_1d_fake), lw=3, label="generated")
  82. plt.legend(loc=1)
  83.  
  84. # 2d plot ##########################################################
  85.  
  86. ax2 = plt.subplot(1, 4, 2)
  87. ax2.set_yticklabels([])
  88. ax2.set_xticklabels([])
  89. ax2.tick_params(direction="in")
  90. plt.title("b) $r=(1, 1, 0)$, $a=(0, 0, 1)$")
  91.  
  92. c1, w1, h1 = plot_ellipse(ax2, mu=mu_2d_real, cov=sigma_2d_real, color="C0")
  93. c2, w2, h2 = plot_ellipse(ax2, mu=mu_2d_fake, cov=sigma_2d_fake, color="C1")
  94. real_patch = mpatches.Patch(color='C0', alpha=0.2, label='real')
  95. fake_patch = mpatches.Patch(color='C1', alpha=0.2, label='fake')
  96. plt.legend(handles=[real_patch, fake_patch], loc=1)
  97.  
  98. xmin = min(mu_2d_real[0] - w1, mu_2d_fake[0] - w2) * 0.7
  99. xmax = max(mu_2d_real[0] + w1, mu_2d_fake[0] + w2) * 0.7
  100.  
  101. ymin = min(mu_2d_real[1] - h1, mu_2d_fake[1] - h2) * 0.7
  102. ymax = max(mu_2d_real[1] + h1, mu_2d_fake[1] + h2) * 0.7
  103.  
  104. plt.xlim(xmin, xmax)
  105. plt.ylim(ymin, ymax)
  106.  
  107. # embeddings r=2 ###################################################
  108.  
  109. ax3 = plt.subplot(1, 4, 3)
  110. ax3.set_yticklabels([])
  111. ax3.set_xticklabels([])
  112. ax3.tick_params(direction="in")
  113. plt.title("c) $r=(0,1,0)$")
  114.  
  115. plt.plot(*emb_2_1, '.', label="$2|1$")
  116. plt.plot(*emb_2_3, '.', label="$2|3$")
  117. plt.plot(*emb_2_13, '.', label="$2|1,3$")
  118. plt.legend(loc=4)
  119.  
  120. # embeddings r=1 ###################################################
  121.  
  122. ax4 = plt.subplot(1, 4, 4)
  123. ax4.set_yticklabels([])
  124. ax4.set_xticklabels([])
  125. plt.title("d) $r=(1,0,0)$")
  126.  
  127. plt.plot(*emb_1_23, '.', label="$1|2,3$")
  128. plt.plot(*emb_1_2, '.', label="$1|2$")
  129. plt.plot(*emb_1_3, '.', label="$1|3$")
  130. plt.legend(loc=4)
  131.  
  132. # finish up ########################################################
  133.  
  134. plt.tight_layout(1, 1, 1)
  135. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement