Advertisement
Guest User

Untitled

a guest
Oct 22nd, 2019
106
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.91 KB | None | 0 0
  1. %matplotlib inline
  2. import matplotlib.pyplot as plt
  3.  
  4.  
  5. n_layers = 4
  6. n_steps = 4
  7.  
  8.  
  9. fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=[10, 3])
  10.  
  11. arrow = dict(width=.01, head_width=.1, color='C0')
  12.  
  13. def plot_architecture(ax):
  14. for step in range(n_steps):
  15. for layer in range(1, n_layers):
  16. ax.scatter(step, layer, color='w', edgecolor='k', zorder=1000, s=100)
  17. for i, word in enumerate('brains are really great'.split()):
  18. ax.text(i, -.5, word, horizontalalignment='center', color='k')
  19.  
  20. ax.set_aspect('equal')
  21. for spine in ('top', 'right', 'left', 'bottom'):
  22. ax.spines[spine].set_visible(False)
  23. ax.set_xticks([])
  24. ax.set_yticks([])
  25. ax.set_xlim(-.5, n_steps-.5)
  26. ax.set_ylim(-1, n_layers-.5)
  27.  
  28. def plot_word_embedding(ax):
  29. plot_architecture(ax)
  30. for step in range(n_steps):
  31. for layer in range(1):
  32. ax.arrow(step, layer, 0, .7, **arrow)
  33. ax.set_title('Word Embedding')
  34.  
  35. def plot_lstm(ax):
  36. plot_architecture(ax)
  37. # feedforward
  38. for step in range(n_steps):
  39. for layer in range(n_layers-1):
  40. ax.arrow(step, layer, 0, .7, **arrow)
  41.  
  42. # recurrence
  43. for step in range(n_steps-1):
  44. for layer in range(1, n_layers):
  45. ax.arrow(step, layer, .7, 0, **arrow)
  46.  
  47. ax.set_title('Causal LSTM')
  48.  
  49. def plot_transformer(ax):
  50. plot_architecture(ax)
  51.  
  52. # feedforward
  53. for step in range(n_steps):
  54. for layer in range(n_layers-1):
  55. ax.arrow(step, layer, 0, .7, **arrow)
  56.  
  57. # attention
  58. for step in range(n_steps-1):
  59. for layer in range(n_layers-1):
  60. for reach in range(n_steps - 1):
  61. if (reach + step + 1) >= n_steps:
  62. continue
  63. ax.arrow(step, layer, (reach+1.) - .2, .7, **arrow)
  64. ax.set_title('Causal Transformer')
  65.  
  66. plot_word_embedding(axes[0])
  67. plot_lstm(axes[1])
  68. plot_transformer(axes[2])
  69. fig.tight_layout()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement