from typing import List, Optional

import tensorflow as tf


class ContextualDynamicMetaEmbedding(tf.keras.layers.Layer):
    def __init__(self,
                 embedding_matrices: List[tf.keras.layers.Embedding],
                 output_dim: Optional[int] = None,
                 n_lstm_units: int = 2,
                 name: str = 'contextual_dynamic_meta_embedding',
                 **kwargs):
        """
        :param embedding_matrices: List of embedding layers
        :param output_dim: Dimension of the output embedding
        :param n_lstm_units: Number of units in each LSTM (denoted `m` in the original article)
        :param name: Layer name
        """
        super().__init__(name=name, **kwargs)
        # Validate that all embedding matrices have the same vocabulary size
        if len(set(e.input_dim for e in embedding_matrices)) != 1:
            raise ValueError('Vocabulary sizes (first dimension) of all embedding matrices must match')

        # If no output_dim is supplied, use the smallest dimension among the given matrices
        self.output_dim = output_dim or min(e.output_dim for e in embedding_matrices)
        self.n_lstm_units = n_lstm_units

        self.embedding_matrices = embedding_matrices
        self.n_embeddings = len(self.embedding_matrices)
        self.projections = [tf.keras.layers.Dense(units=self.output_dim,
                                                  activation=None,
                                                  name='projection_{}'.format(i),
                                                  dtype=self.dtype) for i, e in enumerate(self.embedding_matrices)]

        self.bilstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(units=self.n_lstm_units, return_sequences=True),
            name='bilstm',
            dtype=self.dtype)

        self.attention = tf.keras.layers.Dense(units=1,
                                               activation=None,
                                               name='attention',
                                               dtype=self.dtype)
    def call(self, inputs, **kwargs) -> tf.Tensor:
        batch_size, time_steps = inputs.shape[:2]

        # Embedding lookup
        embedded = [e(inputs) for e in self.embedding_matrices]  # List of shape=(batch_size, time_steps, channels_i)
        # Project each embedding to the common output dimension
        projected = tf.reshape(tf.concat([p(e) for p, e in zip(self.projections, embedded)], axis=-1),
                               shape=(batch_size, time_steps, -1, self.output_dim),
                               name='projected')  # shape=(batch_size, time_steps, n_embeddings, output_dim)
        # Contextualize: run the BiLSTM over the time axis, one sequence per embedding.
        # Transpose before reshaping so each row fed to the BiLSTM is a coherent
        # (time_steps, output_dim) sequence for a single embedding; reshaping
        # (batch_size, time_steps, n_embeddings, output_dim) directly would interleave
        # embeddings across time steps.
        context = self.bilstm(
            tf.reshape(tf.transpose(projected, perm=[0, 2, 1, 3]),
                       shape=(batch_size * self.n_embeddings, time_steps,
                              self.output_dim)))  # shape=(batch_size * n_embeddings, time_steps, n_lstm_units*2)
        context = tf.transpose(
            tf.reshape(context, shape=(batch_size, self.n_embeddings, time_steps, self.n_lstm_units * 2)),
            perm=[0, 2, 1, 3])  # shape=(batch_size, time_steps, n_embeddings, n_lstm_units*2)
        # Calculate attention coefficients over the n_embeddings axis
        alphas = self.attention(context)  # shape=(batch_size, time_steps, n_embeddings, 1)
        alphas = tf.nn.softmax(alphas, axis=-2)  # shape=(batch_size, time_steps, n_embeddings, 1)
        # Attend: weighted sum of the projected embeddings
        output = tf.squeeze(tf.matmul(
            tf.transpose(projected, perm=[0, 1, 3, 2]), alphas),  # shape=(batch_size, time_steps, output_dim, 1)
            axis=-1,
            name='output')  # shape=(batch_size, time_steps, output_dim)

        return output
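

# A minimal usage sketch, assuming TensorFlow 2.x with eager execution.
# The vocabulary size, embedding dimensions, and random token IDs below are
# illustrative values, not taken from the original article.
vocab_size = 1000
embeddings = [tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=300),  # e.g. word2vec-sized
              tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=100)]  # e.g. GloVe-sized

# With no explicit output_dim, both embeddings are projected to min(300, 100) = 100
layer = ContextualDynamicMetaEmbedding(embeddings)

token_ids = tf.random.uniform((8, 25), maxval=vocab_size, dtype=tf.int32)  # (batch_size, time_steps)
meta_embedded = layer(token_ids)
print(meta_embedded.shape)  # (8, 25, 100)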