from typing import List, Optional

import tensorflow as tf


class ContextualDynamicMetaEmbedding(tf.keras.layers.Layer):
    def __init__(self,
                 embedding_matrices: List[tf.keras.layers.Embedding],
                 output_dim: Optional[int] = None,
                 n_lstm_units: int = 2,
                 name: str = 'contextual_dynamic_meta_embedding',
                 **kwargs):
        """
        :param embedding_matrices: List of embedding layers
        :param output_dim: Dimension of the output embedding
        :param n_lstm_units: Number of units in each LSTM (denoted `m` in the original article)
        :param name: Layer name
        """

        super().__init__(name=name, **kwargs)

        # Validate that all the embedding matrices share the same vocabulary size
        if len({e.input_dim for e in embedding_matrices}) != 1:
            raise ValueError('Vocabulary sizes (first dimension) of all embedding matrices must match')

        # If no output_dim is supplied, use the smallest dimension among the given matrices
        self.output_dim = output_dim or min(e.output_dim for e in embedding_matrices)

        self.n_lstm_units = n_lstm_units

        self.embedding_matrices = embedding_matrices
        self.n_embeddings = len(self.embedding_matrices)

        # One linear projection per embedding, mapping each embedding space to output_dim
        self.projections = [tf.keras.layers.Dense(units=self.output_dim,
                                                  activation=None,
                                                  name='projection_{}'.format(i),
                                                  dtype=self.dtype) for i in range(self.n_embeddings)]

        # BiLSTM that contextualizes each projected embedding sequence
        self.bilstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(units=self.n_lstm_units, return_sequences=True),
            name='bilstm',
            dtype=self.dtype)

        # Scores each contextualized embedding; a softmax over these scores gives the attention weights
        self.attention = tf.keras.layers.Dense(units=1,
                                               activation=None,
                                               name='attention',
                                               dtype=self.dtype)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        # Use dynamic shapes so the layer also works when the batch size is unknown at graph-construction time
        input_shape = tf.shape(inputs)
        batch_size, time_steps = input_shape[0], input_shape[1]

        # Embedding lookup
        embedded = [e(inputs) for e in self.embedding_matrices]  # List of shape=(batch_size, time_steps, channels_i)

        # Projection: map every embedding into the shared output space and stack them along a new axis
        projected = tf.reshape(tf.concat([p(e) for p, e in zip(self.projections, embedded)], axis=-1),
                               shape=(batch_size, time_steps, self.n_embeddings, self.output_dim),
                               name='projected')  # shape=(batch_size, time_steps, n_embeddings, output_dim)

        # Contextualize: run the BiLSTM over each embedding's sequence independently.
        # The embeddings axis is moved in front of time before reshaping so that each
        # (batch, embedding) pair becomes one sequence; reshaping directly would
        # interleave time steps across embeddings.
        context = tf.transpose(projected, perm=[0, 2, 1, 3])  # shape=(batch_size, n_embeddings, time_steps, output_dim)
        context = tf.reshape(context, shape=(batch_size * self.n_embeddings, time_steps, self.output_dim))
        context = self.bilstm(context)  # shape=(batch_size * n_embeddings, time_steps, n_lstm_units*2)
        context = tf.reshape(context, shape=(batch_size, self.n_embeddings, time_steps, self.n_lstm_units * 2))
        context = tf.transpose(context, perm=[0, 2, 1, 3])  # shape=(batch_size, time_steps, n_embeddings, n_lstm_units*2)

        # Calculate attention coefficients: one scalar per embedding, normalized over the embeddings axis
        alphas = self.attention(context)  # shape=(batch_size, time_steps, n_embeddings, 1)
        alphas = tf.nn.softmax(alphas, axis=-2)  # shape=(batch_size, time_steps, n_embeddings, 1)

        # Attend: weighted sum of the projected embeddings, using the attention coefficients as weights
        output = tf.matmul(tf.transpose(projected, perm=[0, 1, 3, 2]),
                           alphas)  # shape=(batch_size, time_steps, output_dim, 1)
        output = tf.squeeze(output, axis=-1, name='output')  # shape=(batch_size, time_steps, output_dim)

        return output
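
# --- Usage sketch (added; not part of the original paste) ---
# A minimal, hedged example of wiring the layer up. The vocabulary size, embedding
# dimensions, and random token ids below are made-up placeholders; in practice the
# Embedding layers would carry pretrained weights (e.g. GloVe and fastText). All
# embedding matrices must share the same vocabulary size (input_dim).
if __name__ == '__main__':
    vocab_size = 10000
    glove_like = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=300)
    fasttext_like = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=100)

    dme = ContextualDynamicMetaEmbedding([glove_like, fasttext_like])

    token_ids = tf.random.uniform((4, 12), maxval=vocab_size, dtype=tf.int32)  # (batch_size, time_steps)
    meta_embedded = dme(token_ids)
    print(meta_embedded.shape)  # (4, 12, 100): output_dim defaults to the smallest embedding dimension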
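
# --- Sanity check for the attend step (added; not in the original paste) ---
# A small NumPy sketch, with made-up shapes, showing that the transpose + matmul in
# `call` computes the attention-weighted sum over the embeddings axis:
# output[b, t] = sum_j alphas[b, t, j] * projected[b, t, j].
def _check_attend_step():
    import numpy as np

    b, t, e, d = 2, 5, 3, 4
    projected = np.random.rand(b, t, e, d).astype('float32')
    scores = np.random.rand(b, t, e, 1).astype('float32')
    alphas = np.exp(scores) / np.exp(scores).sum(axis=-2, keepdims=True)  # softmax over the embeddings axis

    via_matmul = np.squeeze(np.transpose(projected, (0, 1, 3, 2)) @ alphas, axis=-1)
    via_weighted_sum = (alphas * projected).sum(axis=-2)
    assert np.allclose(via_matmul, via_weighted_sum, atol=1e-6)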