import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Assumed to be defined elsewhere in the original script.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class _EncoderModule(nn.Module):
    def __init__(self, embeddings_table, embeddings_size, hidden_size):
        super().__init__()

        self.embeddings_table = embeddings_table
        self.embeddings_size = embeddings_size
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(input_size=embeddings_size, hidden_size=hidden_size, bidirectional=True)

    def forward(self, src_sentences, src_lengths):
        # Embed, pack (src_lengths sorted in descending order), run the BiLSTM, then unpack.
        src_embed = self.embeddings_table(src_sentences)
        packd_embed = pack_padded_sequence(src_embed, src_lengths, batch_first=True)
        packd_all_hidden, (hidden_states, cell_states) = self.lstm(packd_embed)
        all_hidden_states, _ = pad_packed_sequence(packd_all_hidden, batch_first=True)

        return all_hidden_states, hidden_states, cell_states

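# A minimal shape check for the encoder. The vocabulary size, dimensions and batch below are
# made-up illustration values, not taken from the original script; lengths are sorted in
# descending order because pack_padded_sequence expects that by default.
_demo_embeddings = nn.Embedding(1000, 256, padding_idx=0)
_demo_encoder = _EncoderModule(_demo_embeddings, embeddings_size=256, hidden_size=512)
_demo_src = torch.randint(1, 1000, (4, 10))   # batch of 4 padded sentences, max length 10
_demo_lengths = torch.tensor([10, 9, 7, 5])   # true lengths, sorted descending
_all_h, _h_n, _c_n = _demo_encoder(_demo_src, _demo_lengths)
print(_all_h.shape)  # torch.Size([4, 10, 1024]) -> 2 * hidden_size (bidirectional outputs)
print(_h_n.shape)    # torch.Size([2, 4, 512])   -> one final state per direction
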
class _DecoderModule(nn.Module):
    def __init__(
        self,
        embeddings_table,
        embeddings_size,
        hidden_size,
        start_idx,
        dropout_prob,
        dst_vocab_size,
        teacher_forcing
    ):
        super().__init__()

        self.embeddings_table = embeddings_table
        self.embeddings_size = embeddings_size
        self.hidden_size = hidden_size
        self.start_idx = torch.tensor(start_idx).to(DEVICE)

        # Projections of the encoder's bidirectional states down to the decoder size,
        # attention projection, combined-output projection, and vocabulary projection.
        self.W_h = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.W_c = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.W_attn = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.W_u = nn.Linear(3 * self.hidden_size, self.hidden_size, bias=False)
        self.W_vocab = nn.Linear(self.hidden_size, dst_vocab_size, bias=False)
        self.lstm_cell = nn.LSTMCell(self.hidden_size + self.embeddings_size, self.hidden_size, bias=True)
        self.dropout = nn.Dropout(p=dropout_prob)

        self.teacher_forcing = teacher_forcing

    def forward(
        self,
        all_enc_hidden_states,
        final_enc_hidden_states,
        final_enc_cell_states,
        max_sentence_length,
        correct_indices=None
    ):
        batch_size = all_enc_hidden_states.shape[0]

        out = []

        # Concatenate both directions of the encoder's final states and project them
        # to initialize the decoder LSTM cell.
        proj_final_enc_hidden_states = self.W_h(
            final_enc_hidden_states.transpose(0, 1).contiguous().view(batch_size, 2 * self.hidden_size))
        proj_final_enc_cell_states = self.W_c(
            final_enc_cell_states.transpose(0, 1).contiguous().view(batch_size, 2 * self.hidden_size))

        y_t = self.embeddings_table(self.start_idx.repeat(batch_size))
        state = (proj_final_enc_hidden_states, proj_final_enc_cell_states)
        o_t = torch.zeros(batch_size, self.hidden_size, device=DEVICE)
        for i in range(max_sentence_length):
            # Feed the previous combined output together with the current input embedding.
            y_barra = torch.cat((y_t, o_t), dim=1)
            state = self.lstm_cell(y_barra, state)
            h_t, _ = state
            # Multiplicative attention over the encoder hidden states.
            e_t = torch.bmm(self.W_attn(all_enc_hidden_states), h_t.unsqueeze(2))
            alfa_t = torch.softmax(e_t, dim=1)  # (batch_size, src_len, 1)
            a_t = torch.sum(alfa_t * all_enc_hidden_states, dim=1)

            u_t = torch.cat((h_t, a_t), dim=1)
            v_t = self.W_u(u_t)

            o_t = self.dropout(torch.tanh(v_t))
            P_t = self.W_vocab(o_t)
            out.append(P_t)
            _, max_indices = P_t.max(dim=1)
            if self.teacher_forcing and self.training:  # 100% teacher forcing
                y_t = self.embeddings_table(correct_indices[:, i])
            else:
                y_t = self.embeddings_table(max_indices)

        return torch.stack(out, dim=1)

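# For reference, each decoder step above computes multiplicative (Luong-style) attention over
# the encoder hidden states; in the notation of the code, with h_i^enc the i-th encoder state
# and h_t the decoder LSTM cell's hidden state:
#
#   e_{t,i} = (W_attn h_i^enc)^T h_t
#   alpha_t = softmax(e_t)
#   a_t     = sum_i alpha_{t,i} h_i^enc
#   v_t     = W_u [h_t ; a_t],   o_t = dropout(tanh(v_t)),   P_t = W_vocab o_t
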
class NeuralMachineTranslator(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        start_idx,
        embeddings_size,
        hidden_size,
        dropout_prob,
        teacher_forcing=False
    ):
        super().__init__()

        self.embeddings_table = nn.Embedding(src_vocab_size, embeddings_size, padding_idx=0)

        self.encoder_module = _EncoderModule(
            embeddings_table=self.embeddings_table,
            embeddings_size=embeddings_size,
            hidden_size=hidden_size)
        self.decoder_module = _DecoderModule(
            embeddings_table=self.embeddings_table,
            embeddings_size=embeddings_size,
            hidden_size=hidden_size,
            start_idx=start_idx,
            dropout_prob=dropout_prob,
            dst_vocab_size=src_vocab_size,
            teacher_forcing=teacher_forcing)

    def forward(
        self,
        src_sentences,
        dst_sentences,  # necessary when using teacher_forcing
        src_lengths,
        dst_lengths
    ):
        max_sentence_length = dst_lengths.max()

        # Call the encoder.
        all_hidden_states, hidden_states, cell_states = self.encoder_module(src_sentences, src_lengths)
        # Call the decoder.
        pre_logits = self.decoder_module(all_hidden_states, hidden_states, cell_states, max_sentence_length, dst_sentences)

        # Note: pre-softmax scores are returned because the cross-entropy loss applies the softmax itself.
        return pre_logits
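
# A hedged end-to-end sketch: vocabulary size, batch contents and hyperparameters are made up,
# padding index 0 matches the nn.Embedding above, and in a real training loop the target
# indices fed to the loss would typically be shifted relative to the decoder inputs.
model = NeuralMachineTranslator(
    src_vocab_size=1000, start_idx=1, embeddings_size=256,
    hidden_size=512, dropout_prob=0.3, teacher_forcing=True,
).to(DEVICE)

src = torch.randint(1, 1000, (4, 10), device=DEVICE)   # padded source batch
dst = torch.randint(1, 1000, (4, 12), device=DEVICE)   # padded target batch
src_lengths = torch.tensor([10, 9, 7, 5])               # kept on CPU for pack_padded_sequence
dst_lengths = torch.tensor([12, 11, 8, 6])

pre_logits = model(src, dst, src_lengths, dst_lengths)  # (4, 12, 1000) pre-softmax scores
loss = nn.CrossEntropyLoss(ignore_index=0)(             # CrossEntropyLoss applies log-softmax itself
    pre_logits.reshape(-1, pre_logits.size(-1)), dst.reshape(-1))
loss.backward()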