import tensorflow as tf
import tensorflow.keras.layers as L


class AttentionLayer(L.Layer):
    def __init__(self, name, enc_size, dec_size, hid_size, activ=tf.tanh):
        """ A layer that computes additive attention response and weights """
        super().__init__(name=name)  # let Keras register the layer name
        self.enc_size = enc_size  # num units in encoder state
        self.dec_size = dec_size  # num units in decoder state
        self.hid_size = hid_size  # attention layer hidden units
        self.activ = activ  # attention layer hidden nonlinearity

        # dense projections for encoder states, decoder state, and the scalar score
        self.linear_e = L.Dense(self.hid_size)
        self.linear_d = L.Dense(self.hid_size)
        self.linear_out = L.Dense(1)

    def build(self, input_shape):
        # create layer variables; each Dense only needs its last (feature) dimension
        self.linear_e.build((None, self.enc_size))
        self.linear_d.build((None, self.dec_size))
        self.linear_out.build((None, self.hid_size))

        # Hint: you can find an example of a custom layer here:
        # https://www.tensorflow.org/tutorials/customization/custom_layers

    def call(self, enc, dec, inp_mask):
        """
        Computes attention response and weights
        :param enc: encoder activation sequence, float32[batch_size, ninp, enc_size]
        :param dec: single decoder state used as "query", float32[batch_size, dec_size]
        :param inp_mask: mask on enc activations (0 after first eos), float32[batch_size, ninp]
        :returns: attn[batch_size, enc_size], probs[batch_size, ninp]
            - attn - attention response vector (weighted sum of enc)
            - probs - attention weights after softmax
        """
        # Compute additive attention logits: score each encoder step against the decoder "query"
        hidden_with_time_axis = tf.expand_dims(dec, 1)  # [batch_size, 1, dec_size]
        logits = self.linear_out(
            self.activ(
                self.linear_e(enc)
                + self.linear_d(tf.tile(hidden_with_time_axis, [1, enc.shape[1], 1]))
            )
        )  # [batch_size, ninp, 1]

        # Apply mask: where inp_mask is 0, force the logit to -1e9 so the position
        # receives ~zero probability after softmax (tf.where keeps `logits` only
        # where the mask is nonzero)
        mask_new = tf.expand_dims(inp_mask, 2)  # [batch_size, ninp, 1]
        masked_logits = tf.where(tf.equal(mask_new, 0.0),
                                 -1e9 * tf.ones_like(logits),
                                 logits)

        # Compute attention probabilities (softmax over the time axis)
        probs = tf.nn.softmax(masked_logits, axis=1)  # [batch_size, ninp, 1]

        # Compute attention response using enc and probs
        attn = tf.reduce_sum(probs * enc, axis=1)  # [batch_size, enc_size]

        return attn, probs[:, :, 0]
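
A minimal usage sketch, assuming TF 2.x eager execution; the batch size, sequence length, unit sizes, and per-example lengths below are made-up values for illustration only:

# Hypothetical smoke test: shapes follow the docstring of call()
batch_size, ninp, enc_size, dec_size, hid_size = 4, 7, 32, 24, 16

attn_layer = AttentionLayer("attn", enc_size, dec_size, hid_size)

enc = tf.random.normal([batch_size, ninp, enc_size])   # encoder activations
dec = tf.random.normal([batch_size, dec_size])         # decoder "query" state
# mask is 1 for real tokens and 0 after the first eos (lengths here are invented)
inp_mask = tf.cast(tf.sequence_mask([7, 5, 3, 6], maxlen=ninp), tf.float32)

attn, probs = attn_layer(enc, dec, inp_mask)
print(attn.shape)   # (4, 32)  -> [batch_size, enc_size]
print(probs.shape)  # (4, 7)   -> [batch_size, ninp]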