SHARE
TWEET

Untitled

a guest Jun 27th, 2019 78 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. def __init__(self, **kwargs):
  2.     super(AttentionLayer, self).__init__(**kwargs)
  3.  
def build(self, input_shape):
    """Create the trainable attention weights W_a, U_a and V_a.

    Args:
        input_shape: list of two shapes,
            ``[encoder_output_shape, decoder_output_shape]`` — presumably
            ``(batch, en_seq_len, en_hidden)`` and
            ``(batch, de_seq_len, de_hidden)`` given how ``call`` unpacks
            its inputs; TODO confirm against the caller.
    """
    assert isinstance(input_shape, list)

    # Create a trainable weight variable for this layer.
    #
    # NOTE(review): every occurrence of input_shape[0][1] below indexes
    # the encoder *time* axis (en_seq_len). The canonical Bahdanau
    # attention layer sizes these weights with input_shape[0][2] (the
    # encoder hidden size) instead. This only works if en_seq_len is
    # fixed and matches the reshape widths used inside call() — confirm
    # the deviation is intentional.

    self.W_a = self.add_weight(name='W_a',
                               shape=tf.TensorShape((input_shape[0][1], input_shape[0][1])),
                               initializer='uniform',
                               trainable=True)
    self.U_a = self.add_weight(name='U_a',
                               shape=tf.TensorShape((input_shape[1][2], input_shape[0][1])),
                               initializer='uniform',
                               trainable=True)
    self.V_a = self.add_weight(name='V_a',
                               shape=tf.TensorShape((input_shape[0][1], 1)),
                               initializer='uniform',
                               trainable=True)

    super(AttentionLayer, self).build(input_shape)  # Be sure to call this at the end
  23.  
def call(self, inputs, verbose=False):
    """
    Compute Bahdanau-style attention over the encoder outputs.

    inputs: [encoder_output_sequence, decoder_output_sequence]

    For each decoder step, attention energies over the encoder steps are
    produced via K.rnn(energy_step, ...), then the context vectors are
    accumulated via K.rnn(context_step, ...).

    Returns:
        (c_outputs, e_outputs): the per-decoder-step context vectors and
        the softmaxed attention energies.
    """
    assert type(inputs) == list
    encoder_out_seq, decoder_out_seq = inputs
    if verbose:
        print('encoder_out_seq>', encoder_out_seq.shape)
        print('decoder_out_seq>', decoder_out_seq.shape)

    def energy_step(inputs, states):
        """ Step function for computing energy for a single decoder state """

        assert_msg = "States must be a list. However states {} is of type {}".format(states, type(states))
        assert isinstance(states, list) or isinstance(states, tuple), assert_msg

        """ Some parameters required for shaping tensors"""
        print(encoder_out_seq.shape)
        # NOTE(review): despite its name, this reads axis 1 of the
        # encoder output, i.e. the encoder *sequence length*, not the
        # hidden size (axis -1). All reshapes below therefore use the
        # time dimension as the feature width — confirm intentional.
        en_hidden = encoder_out_seq.shape[1]
        de_hidden = inputs.shape[-1]
        """ Computing S.Wa where S=[s0, s1, ..., si]"""
        # <= batch_size*en_seq_len, latent_dim
        reshaped_enc_outputs = K.reshape(encoder_out_seq, (-1, en_hidden))
        # <= batch_size*en_seq_len, latent_dim
        W_a_dot_s = K.reshape(K.dot(reshaped_enc_outputs, self.W_a), (-1,en_hidden))
        if verbose:
            print('wa.s>',W_a_dot_s.shape)

        """ Computing hj.Ua """
        U_a_dot_h = K.expand_dims(K.dot(inputs, self.U_a), 1)  # <= batch_size, 1, latent_dim
        if verbose:
            print('Ua.h>',U_a_dot_h.shape)

        """ tanh(S.Wa + hj.Ua) """
        # <= batch_size*en_seq_len, latent_dim
        reshaped_Ws_plus_Uh = K.tanh(K.reshape(W_a_dot_s + U_a_dot_h, (-1, en_hidden)))
        if verbose:
            print('Ws+Uh>', reshaped_Ws_plus_Uh.shape)

        """ softmax(va.tanh(S.Wa + hj.Ua)) """
        # <= batch_size, en_seq_len
        # NOTE(review): hard-coded 256 — presumably equal to en_seq_len
        # for the model this was written for. Any other encoder sequence
        # length will make this reshape fail or silently mis-batch;
        # should derive the width from encoder_out_seq instead.
        e_i = K.reshape(K.dot(reshaped_Ws_plus_Uh, self.V_a), (-1,256))

        # <= batch_size, en_seq_len
        e_i = K.softmax(e_i)

        if verbose:
            print('ei>', e_i.shape)

        return e_i, [e_i]

    def context_step(inputs, states):
        """ Step function for computing ci using ei """
        # Weighted sum of encoder outputs with the attention weights.
        # <= batch_size, hidden_size
        c_i = K.sum(encoder_out_seq * K.expand_dims(inputs, -1), axis=1)
        if verbose:
            print('ci>', c_i.shape)
        return c_i, [c_i]

    def create_inital_state(inputs, hidden_size):
        print("inpuuut",inputs.shape)
        print(" hidden_size==", hidden_size)
        # We are not using initial states, but need to pass something to the K.rnn function
        fake_state = K.zeros_like(inputs)  # <= (batch_size, enc_seq_len, latent_dim
        fake_state = K.sum(fake_state, axis=1)  # <= (batch_size)
        fake_state = K.expand_dims(fake_state)  # <= (batch_size, 1)
        fake_state = K.tile(fake_state, [1, hidden_size])  # <= (batch_size, latent_dim
        return fake_state

    # Zero states sized to match what each step function returns:
    # context steps emit (batch, latent_dim), energy steps emit
    # (batch, en_seq_len).
    fake_state_c = create_inital_state(encoder_out_seq, encoder_out_seq.shape[-1])
    fake_state_e = create_inital_state(encoder_out_seq, encoder_out_seq.shape[1])  # <= (batch_size, enc_seq_len, latent_dim)

    """ Computing energy outputs """
    # e_outputs => (batch_size, de_seq_len, en_seq_len)
    last_out, e_outputs, _ = K.rnn(energy_step, decoder_out_seq, [fake_state_e],)

    """ Computing context vectors """
    last_out, c_outputs, _ = K.rnn(
        context_step, e_outputs, [fake_state_c],
    )

    return c_outputs, e_outputs
  106. def compute_output_shape(self, input_shape):
  107.     """ Outputs produced by the layer """
  108.     return [
  109.         tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[1][2])),
  110.         tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][1]))
  111.            ]
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Not a member of Pastebin yet?
Sign Up, it unlocks many cool features!
 
Top