Advertisement
Guest User

Untitled

a guest
Jun 27th, 2019
147
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.50 KB | None | 0 0
  1. def __init__(self, **kwargs):
  2. super(AttentionLayer, self).__init__(**kwargs)
  3.  
  4. def build(self, input_shape):
  5. assert isinstance(input_shape, list)
  6.  
  7. # Create a trainable weight variable for this layer.
  8.  
  9. self.W_a = self.add_weight(name='W_a',
  10. shape=tf.TensorShape((input_shape[0][1], input_shape[0][1])),
  11. initializer='uniform',
  12. trainable=True)
  13. self.U_a = self.add_weight(name='U_a',
  14. shape=tf.TensorShape((input_shape[1][2], input_shape[0][1])),
  15. initializer='uniform',
  16. trainable=True)
  17. self.V_a = self.add_weight(name='V_a',
  18. shape=tf.TensorShape((input_shape[0][1], 1)),
  19. initializer='uniform',
  20. trainable=True)
  21.  
  22. super(AttentionLayer, self).build(input_shape) # Be sure to call this at the end
  23.  
  24. def call(self, inputs, verbose=False):
  25. """
  26. inputs: [encoder_output_sequence, decoder_output_sequence]
  27. """
  28. assert type(inputs) == list
  29. encoder_out_seq, decoder_out_seq = inputs
  30. if verbose:
  31. print('encoder_out_seq>', encoder_out_seq.shape)
  32. print('decoder_out_seq>', decoder_out_seq.shape)
  33.  
  34. def energy_step(inputs, states):
  35. """ Step function for computing energy for a single decoder state """
  36.  
  37. assert_msg = "States must be a list. However states {} is of type {}".format(states, type(states))
  38. assert isinstance(states, list) or isinstance(states, tuple), assert_msg
  39.  
  40. """ Some parameters required for shaping tensors"""
  41. print(encoder_out_seq.shape)
  42. en_hidden = encoder_out_seq.shape[1]
  43. de_hidden = inputs.shape[-1]
  44. """ Computing S.Wa where S=[s0, s1, ..., si]"""
  45. # <= batch_size*en_seq_len, latent_dim
  46. reshaped_enc_outputs = K.reshape(encoder_out_seq, (-1, en_hidden))
  47. # <= batch_size*en_seq_len, latent_dim
  48. W_a_dot_s = K.reshape(K.dot(reshaped_enc_outputs, self.W_a), (-1,en_hidden))
  49. if verbose:
  50. print('wa.s>',W_a_dot_s.shape)
  51.  
  52. """ Computing hj.Ua """
  53. U_a_dot_h = K.expand_dims(K.dot(inputs, self.U_a), 1) # <= batch_size, 1, latent_dim
  54. if verbose:
  55. print('Ua.h>',U_a_dot_h.shape)
  56.  
  57. """ tanh(S.Wa + hj.Ua) """
  58. # <= batch_size*en_seq_len, latent_dim
  59. reshaped_Ws_plus_Uh = K.tanh(K.reshape(W_a_dot_s + U_a_dot_h, (-1, en_hidden)))
  60. if verbose:
  61. print('Ws+Uh>', reshaped_Ws_plus_Uh.shape)
  62.  
  63. """ softmax(va.tanh(S.Wa + hj.Ua)) """
  64. # <= batch_size, en_seq_len
  65. e_i = K.reshape(K.dot(reshaped_Ws_plus_Uh, self.V_a), (-1,256))
  66.  
  67. # <= batch_size, en_seq_len
  68. e_i = K.softmax(e_i)
  69.  
  70. if verbose:
  71. print('ei>', e_i.shape)
  72.  
  73. return e_i, [e_i]
  74.  
  75. def context_step(inputs, states):
  76. """ Step function for computing ci using ei """
  77. # <= batch_size, hidden_size
  78. c_i = K.sum(encoder_out_seq * K.expand_dims(inputs, -1), axis=1)
  79. if verbose:
  80. print('ci>', c_i.shape)
  81. return c_i, [c_i]
  82.  
  83. def create_inital_state(inputs, hidden_size):
  84. print("inpuuut",inputs.shape)
  85. print(" hidden_size==", hidden_size)
  86. # We are not using initial states, but need to pass something to K.rnn funciton
  87. fake_state = K.zeros_like(inputs) # <= (batch_size, enc_seq_len, latent_dim
  88. fake_state = K.sum(fake_state, axis=1) # <= (batch_size)
  89. fake_state = K.expand_dims(fake_state) # <= (batch_size, 1)
  90. fake_state = K.tile(fake_state, [1, hidden_size]) # <= (batch_size, latent_dim
  91. return fake_state
  92.  
  93. fake_state_c = create_inital_state(encoder_out_seq, encoder_out_seq.shape[-1])
  94. fake_state_e = create_inital_state(encoder_out_seq, encoder_out_seq.shape[1]) # <= (batch_size, enc_seq_len, latent_dim)
  95.  
  96. """ Computing energy outputs """
  97. # e_outputs => (batch_size, de_seq_len, en_seq_len)
  98. last_out, e_outputs, _ = K.rnn(energy_step, decoder_out_seq, [fake_state_e],)
  99.  
  100. """ Computing context vectors """
  101. last_out, c_outputs, _ = K.rnn(
  102. context_step, e_outputs, [fake_state_c],
  103. )
  104.  
  105. return c_outputs, e_outputs
  106. def compute_output_shape(self, input_shape):
  107. """ Outputs produced by the layer """
  108. return [
  109. tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[1][2])),
  110. tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][1]))
  111. ]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement