from typing import List, Optional

import tensorflow as tf


class ContextualDynamicMetaEmbedding(tf.keras.layers.Layer):
    def __init__(self,
                 embedding_matrices: List[tf.keras.layers.Embedding],
                 output_dim: Optional[int] = None,
                 n_lstm_units: int = 2,
                 name: str = 'contextual_dynamic_meta_embedding',
                 **kwargs):
        """
        :param embedding_matrices: List of embedding layers
        :param output_dim: Dimension of the output embedding
        :param n_lstm_units: Number of units in each LSTM (denoted `m` in the original article)
        :param name: Layer name
        """
        super().__init__(name=name, **kwargs)

        # Validate that all embedding matrices share the same vocabulary size
        if len(set(e.input_dim for e in embedding_matrices)) != 1:
            raise ValueError('Vocabulary sizes (first dimension) of all embedding matrices must match')

        # If no output_dim is supplied, use the smallest dimension among the given matrices
        self.output_dim = output_dim or min(e.output_dim for e in embedding_matrices)

        self.n_lstm_units = n_lstm_units

        self.embedding_matrices = embedding_matrices
        self.n_embeddings = len(self.embedding_matrices)

        # One linear projection per embedding space, mapping it into the shared output_dim
        self.projections = [tf.keras.layers.Dense(units=self.output_dim,
                                                  activation=None,
                                                  name='projection_{}'.format(i),
                                                  dtype=self.dtype)
                            for i in range(self.n_embeddings)]

        # BiLSTM that contextualizes the projected embeddings along the time axis
        self.bilstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(units=self.n_lstm_units, return_sequences=True),
            name='bilstm',
            dtype=self.dtype)

        # Produces a scalar attention score per (token, embedding) pair
        self.attention = tf.keras.layers.Dense(units=1,
                                               activation=None,
                                               name='attention',
                                               dtype=self.dtype)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        # Assumes statically known batch and sequence dimensions (eager or fixed-shape inputs)
        batch_size, time_steps = inputs.shape[:2]

        # Embedding lookup
        embedded = [e(inputs) for e in self.embedding_matrices]  # List of shape=(batch_size, time_steps, channels_i)

        # Project each embedding into the shared output space
        projected = tf.reshape(tf.concat([p(e) for p, e in zip(self.projections, embedded)], axis=-1),
                               shape=(batch_size, time_steps, -1, self.output_dim),
                               name='projected')  # shape=(batch_size, time_steps, n_embeddings, output_dim)

        # Contextualize: run the BiLSTM over the time axis, once per (batch, embedding) pair.
        # Transpose to (batch_size, n_embeddings, time_steps, output_dim) first so each embedding
        # forms a contiguous sequence over time before the first two axes are collapsed.
        context = self.bilstm(
            tf.reshape(tf.transpose(projected, perm=[0, 2, 1, 3]),
                       shape=(batch_size * self.n_embeddings, time_steps,
                              self.output_dim)))  # shape=(batch_size * n_embeddings, time_steps, n_lstm_units*2)
        context = tf.transpose(
            tf.reshape(context, shape=(batch_size, self.n_embeddings, time_steps,
                                       self.n_lstm_units * 2)),
            perm=[0, 2, 1, 3])  # shape=(batch_size, time_steps, n_embeddings, n_lstm_units*2)

        # Calculate attention coefficients over the embeddings
        alphas = self.attention(context)  # shape=(batch_size, time_steps, n_embeddings, 1)
        alphas = tf.nn.softmax(alphas, axis=-2)  # shape=(batch_size, time_steps, n_embeddings, 1)

        # Attend: weighted sum of the projected embeddings
        output = tf.squeeze(tf.matmul(
            tf.transpose(projected, perm=[0, 1, 3, 2]), alphas),  # shape=(batch_size, time_steps, output_dim, 1)
            axis=-1, name='output')  # shape=(batch_size, time_steps, output_dim)

        return output
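
# Below is a minimal usage sketch, not part of the original paste. The vocabulary size,
# sequence length, embedding dimensions, and output_dim are made-up illustrative values;
# it shows how the layer combines two embedding tables into a single meta-embedding.
import tensorflow as tf

vocab_size = 10000  # hypothetical shared vocabulary size

# Two embedding tables with different dimensionalities (e.g. two pretrained sets)
embedding_a = tf.keras.layers.Embedding(vocab_size, 300)
embedding_b = tf.keras.layers.Embedding(vocab_size, 100)

layer = ContextualDynamicMetaEmbedding([embedding_a, embedding_b], output_dim=128)

token_ids = tf.random.uniform((4, 12), maxval=vocab_size, dtype=tf.int32)  # batch of 4 sequences, 12 tokens each
meta_embedded = layer(token_ids)
print(meta_embedded.shape)  # (4, 12, 128)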