Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- %matplotlib inline
- import matplotlib.pyplot as plt
# Grid dimensions used by every panel: layers run along y, steps along x.
n_layers = 4
n_steps = 4

# A single row of three panels that share both axes.
fig, axes = plt.subplots(
    nrows=1,
    ncols=3,
    sharex=True,
    sharey=True,
    figsize=[10, 3],
)

# Keyword arguments reused for each ax.arrow(...) call in the plot functions.
arrow = {'width': .01, 'head_width': .1, 'color': 'C0'}
def plot_architecture(ax):
    """Draw the backdrop common to all panels on ``ax``.

    Places a white, black-edged marker at every (step, layer) position for
    layers ``1 .. n_layers - 1``, writes one word under each column at
    ``y = -0.5``, then hides spines and ticks and fixes the axis limits.
    Reads the module-level ``n_steps`` and ``n_layers``.
    """
    # Node markers; the high zorder keeps them above anything drawn later.
    marker_style = dict(color='w', edgecolor='k', zorder=1000, s=100)
    grid = [(col, row) for col in range(n_steps) for row in range(1, n_layers)]
    for col, row in grid:
        ax.scatter(col, row, **marker_style)

    # One word per column, centred beneath the grid.
    tokens = 'brains are really great'.split()
    for col, token in enumerate(tokens):
        ax.text(col, -.5, token, horizontalalignment='center', color='k')

    ax.set_aspect('equal')

    # Strip the frame and tick marks so only the diagram remains.
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])

    # Half a unit of padding around the grid, plus room below for the words.
    ax.set_xlim(-.5, n_steps - .5)
    ax.set_ylim(-1, n_layers - .5)
def plot_word_embedding(ax):
    """Draw the word-embedding panel on ``ax``.

    Draws the shared backdrop, then one vertical arrow per step starting at
    layer 0, and sets the panel title.
    """
    plot_architecture(ax)

    # Only the bottom row (layer 0) gets an upward arrow in this panel.
    bottom_layer = 0
    for column in range(n_steps):
        ax.arrow(column, bottom_layer, 0, .7, **arrow)

    ax.set_title('Word Embedding')
def plot_lstm(ax):
    """Draw the causal-LSTM panel on ``ax``.

    Draws the shared backdrop, a vertical arrow out of every position in
    layers ``0 .. n_layers - 2``, a horizontal arrow out of every position in
    layers ``1 .. n_layers - 1`` except those in the last step, and the title.
    """
    plot_architecture(ax)

    # Feedforward: upward arrows from each layer except the top one.
    top_layer = n_layers - 1
    for column in range(n_steps):
        for row in range(top_layer):
            ax.arrow(column, row, 0, .7, **arrow)

    # Recurrence: rightward arrows from each step except the last one.
    last_step = n_steps - 1
    for column in range(last_step):
        for row in range(1, n_layers):
            ax.arrow(column, row, .7, 0, **arrow)

    ax.set_title('Causal LSTM')
def plot_transformer(ax):
    """Draw the causal-transformer panel on ``ax``.

    Draws the shared backdrop, a vertical arrow out of every position in
    layers ``0 .. n_layers - 2``, and, from each of those positions, a
    diagonal arrow towards every later step one layer up. Sets the title.
    """
    plot_architecture(ax)

    # Feedforward: upward arrows from each layer except the top one.
    top_layer = n_layers - 1
    for column in range(n_steps):
        for row in range(top_layer):
            ax.arrow(column, row, 0, .7, **arrow)

    # Attention: from each step, one arrow per strictly later step.
    for column in range(n_steps - 1):
        for row in range(top_layer):
            for target in range(column + 1, n_steps):
                # Horizontal run is the step distance, shortened by 0.2.
                hop = float(target - column)
                ax.arrow(column, row, hop - .2, .7, **arrow)

    ax.set_title('Causal Transformer')
# Render one diagram per panel, left to right, then tidy the spacing.
panel_painters = (plot_word_embedding, plot_lstm, plot_transformer)
for paint, panel in zip(panel_painters, axes):
    paint(panel)
fig.tight_layout()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement