import tensorflow as tf


class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(self.dec_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')
        self.fc = tf.keras.layers.Dense(vocab_size)

        # used for attention
        self.attention = BahdanauAttention(self.dec_units)

    def call(self, x, hidden, enc_output):
        # enc_output shape == (batch_size, max_length, hidden_size)
        context_vector, attention_weights = self.attention(hidden, enc_output)

        # x shape after passing through embedding == (batch_size, 1, embedding_dim)
        x = self.embedding(x)

        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # passing the concatenated vector to the GRU
        output, state = self.gru(x)

        # output shape == (batch_size * 1, hidden_size)
        output = tf.reshape(output, (-1, output.shape[2]))

        # output shape == (batch_size, vocab_size)
        x = self.fc(output)

        return x, state, attention_weights

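The Decoder's __init__ refers to a BahdanauAttention class that is not included in this paste. A minimal sketch of additive (Bahdanau) attention that is compatible with the call above (query = decoder hidden state, values = encoder output, returning a context vector and the attention weights) might look like the following; it follows the standard formulation and is an assumption rather than the paste author's own code:

class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        # query shape == (batch_size, hidden_size); add a time axis for broadcasting
        query_with_time_axis = tf.expand_dims(query, 1)

        # score shape == (batch_size, max_length, 1)
        score = self.V(tf.nn.tanh(
            self.W1(query_with_time_axis) + self.W2(values)))

        # attention_weights shape == (batch_size, max_length, 1)
        attention_weights = tf.nn.softmax(score, axis=1)

        # context_vector shape after the sum == (batch_size, hidden_size)
        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights

Python only resolves the BahdanauAttention name when Decoder is instantiated, so defining it anywhere before the instantiation below is sufficient.
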
# vocab_out_size, embedding_dim, units and BATCH_SIZE are assumed to be defined earlier in the script
decoder = Decoder(vocab_out_size, embedding_dim, units, BATCH_SIZE)
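
As a quick, hypothetical shape check (units, BATCH_SIZE, vocab_out_size and embedding_dim are assumed to be defined above; max_length is an illustrative name for the encoder's sequence length):

sample_hidden = tf.zeros((BATCH_SIZE, units))                  # decoder hidden state
sample_enc_output = tf.zeros((BATCH_SIZE, max_length, units))  # encoder outputs
sample_x = tf.zeros((BATCH_SIZE, 1), dtype=tf.int32)           # one target token per example
logits, state, attn_weights = decoder(sample_x, sample_hidden, sample_enc_output)
# logits: (BATCH_SIZE, vocab_out_size), state: (BATCH_SIZE, units),
# attn_weights: (BATCH_SIZE, max_length, 1)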