import tensorflow as tf
from tensorflow.keras import layers as L


class AttentionLayer(L.Layer):
    def __init__(self, name, enc_size, dec_size, hid_size, activ=tf.tanh):
        """ A layer that computes additive attention response and weights """
        super().__init__(name=name)
        self.enc_size = enc_size  # num units in encoder state
        self.dec_size = dec_size  # num units in decoder state
        self.hid_size = hid_size  # attention layer hidden units
        self.activ = activ        # attention layer hidden nonlinearity
        self.linear_e = L.Dense(self.hid_size)
        self.linear_d = L.Dense(self.hid_size)
        self.linear_out = L.Dense(1)

    def build(self, input_shape):
        # create layer variables; Dense.build expects a full input shape,
        # not just the number of input features
        self.linear_e.build((None, None, self.enc_size))
        self.linear_d.build((None, 1, self.dec_size))
        self.linear_out.build((None, None, self.hid_size))
        # Hint: you can find an example of a custom layer here:
        # https://www.tensorflow.org/tutorials/customization/custom_layers

    def call(self, enc, dec, inp_mask):
        """
        Computes attention response and weights
        :param enc: encoder activation sequence, float32[batch_size, ninp, enc_size]
        :param dec: single decoder state used as "query", float32[batch_size, dec_size]
        :param inp_mask: mask on enc activations (0 after first eos), float32[batch_size, ninp]
        :returns: attn[batch_size, enc_size], probs[batch_size, ninp]
            - attn - attention response vector (weighted sum of enc)
            - probs - attention weights after softmax
        """
        # Compute logits: project encoder states and the decoder query into a
        # shared hidden space, apply the nonlinearity, then collapse to one
        # scalar score per encoder position
        hidden_with_time_axis = tf.expand_dims(dec, 1)  # [batch_size, 1, dec_size]
        logits = self.linear_out(
            self.activ(
                self.linear_e(enc) +
                self.linear_d(hidden_with_time_axis)  # broadcasts over the time axis
            )
        )  # [batch_size, ninp, 1]

        # Apply mask: wherever inp_mask is 0, replace the logit with -1e9 so
        # the corresponding softmax weight is effectively zero. Note the
        # argument order of tf.where: keep logits where the mask is nonzero.
        mask_new = tf.expand_dims(inp_mask, 2)  # [batch_size, ninp, 1]
        masked_logits = tf.where(mask_new > 0, logits, -1e9 * tf.ones_like(logits))

        # Compute attention probabilities (softmax over encoder positions)
        probs = tf.nn.softmax(masked_logits, axis=1)

        # Compute attention response as a probability-weighted sum of enc
        attn = tf.reduce_sum(probs * enc, axis=1)
        return attn, probs[:, :, 0]
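
# A minimal usage sketch for the layer above. The sizes, the layer name
# 'attn', and the random inputs are arbitrary placeholders chosen only to
# illustrate the expected shapes and the masking behavior.
batch_size, ninp, enc_size, dec_size, hid_size = 4, 7, 16, 16, 32
attn_layer = AttentionLayer('attn', enc_size, dec_size, hid_size)

enc = tf.random.normal([batch_size, ninp, enc_size])
dec = tf.random.normal([batch_size, dec_size])
# mask out the last two positions of every sequence, as if past an eos token
inp_mask = tf.concat([tf.ones([batch_size, ninp - 2]),
                      tf.zeros([batch_size, 2])], axis=1)

attn, probs = attn_layer(enc, dec, inp_mask)
print(attn.shape)   # (4, 16)
print(probs.shape)  # (4, 7)
# masked positions should carry near-zero attention weight
print(probs.numpy()[:, -2:].max())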