import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class _EncoderModule(nn.Module):
    def __init__(self, embeddings_table, embeddings_size, hidden_size):
        super().__init__()

        self.embeddings_table = embeddings_table
        self.embeddings_size = embeddings_size
        self.hidden_size = hidden_size

        # Bidirectional encoder: each position yields a 2*hidden_size representation.
        self.lstm = nn.LSTM(input_size=embeddings_size, hidden_size=hidden_size, bidirectional=True)

    def forward(self, src_sentences, src_lengths):
        # src_sentences: (batch, max_src_len); src_lengths must be sorted in decreasing
        # order (or pass enforce_sorted=False to pack_padded_sequence).
        src_embed = self.embeddings_table(src_sentences)
        packd_embed = pack_padded_sequence(src_embed, src_lengths, batch_first=True)
        packd_all_hidden, (hidden_states, cell_states) = self.lstm(packd_embed)
        all_hidden_states, _ = pad_packed_sequence(packd_all_hidden, batch_first=True)

        return all_hidden_states, hidden_states, cell_states


class _DecoderModule(nn.Module):
    def __init__(
            self,
            embeddings_table,
            embeddings_size,
            hidden_size,
            start_idx,
            dropout_prob,
            dst_vocab_size,
            teacher_forcing
        ):
        super().__init__()

        self.embeddings_table = embeddings_table
        self.embeddings_size = embeddings_size
        self.hidden_size = hidden_size
        self.start_idx = torch.tensor(start_idx).to(DEVICE)

        # Projections from the bidirectional encoder states (2*hidden_size) down to hidden_size.
        self.W_h = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.W_c = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.W_attn = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.W_u = nn.Linear(3 * self.hidden_size, self.hidden_size, bias=False)
        self.W_vocab = nn.Linear(self.hidden_size, dst_vocab_size, bias=False)
        self.lstm_cell = nn.LSTMCell(self.hidden_size + self.embeddings_size, self.hidden_size, bias=True)
        self.dropout = nn.Dropout(p=dropout_prob)

        self.teacher_forcing = teacher_forcing

    def forward(
            self,
            all_enc_hidden_states,
            final_enc_hidden_states,
            final_enc_cell_states,
            max_sentence_length,
            correct_indices=None
        ):
        BATCH_SIZE = all_enc_hidden_states.shape[0]

        out = []

        # Initialise the decoder state from the final forward/backward encoder states.
        proj_final_enc_hidden_states = self.W_h(final_enc_hidden_states.transpose(0, 1).contiguous().view(BATCH_SIZE, 2 * self.hidden_size))
        proj_final_enc_cell_states = self.W_c(final_enc_cell_states.transpose(0, 1).contiguous().view(BATCH_SIZE, 2 * self.hidden_size))

        y_t = self.embeddings_table(self.start_idx.repeat(BATCH_SIZE))
        state = (proj_final_enc_hidden_states, proj_final_enc_cell_states)
        o_t = torch.zeros(BATCH_SIZE, self.hidden_size, device=DEVICE)
        for i in range(max_sentence_length):
            # Input feeding: concatenate the previous attentional output with the current embedding.
            y_barra = torch.cat((y_t, o_t), dim=1)
            state = self.lstm_cell(y_barra, state)
            h_t, _ = state
            # Multiplicative attention over all encoder hidden states.
            e_t = torch.bmm(self.W_attn(all_enc_hidden_states), h_t.unsqueeze(2))
            alfa_t = torch.softmax(e_t, dim=1)   # b_size * m * 1
            a_t = torch.sum(alfa_t * all_enc_hidden_states, dim=1)

            u_t = torch.cat((h_t, a_t), dim=1)
            v_t = self.W_u(u_t)

            o_t = self.dropout(torch.tanh(v_t))
            P_t = self.W_vocab(o_t)
            out.append(P_t)
            _, max_indices = P_t.max(dim=1)
            if self.teacher_forcing and self.training:        # 100% teacher forcing
                y_t = self.embeddings_table(correct_indices[:, i])
            else:
                y_t = self.embeddings_table(max_indices)

        return torch.stack(out, dim=1)


class NeuralMachineTranslator(nn.Module):
    def __init__(
            self,
            src_vocab_size,
            start_idx,
            embeddings_size,
            hidden_size,
            dropout_prob,
            teacher_forcing=False
        ):
        super().__init__()

        # Embedding table shared by encoder and decoder (single vocabulary).
        self.embeddings_table = nn.Embedding(src_vocab_size, embeddings_size, padding_idx=0)

        self.encoder_module = _EncoderModule(
            embeddings_table=self.embeddings_table,
            embeddings_size=embeddings_size,
            hidden_size=hidden_size)
        self.decoder_module = _DecoderModule(
            embeddings_table=self.embeddings_table,
            embeddings_size=embeddings_size,
            hidden_size=hidden_size,
            start_idx=start_idx,
            dropout_prob=dropout_prob,
            dst_vocab_size=src_vocab_size,
            teacher_forcing=teacher_forcing)

    def forward(
            self,
            src_sentences,
            dst_sentences,    # necessary when using teacher_forcing
            src_lengths,
            dst_lengths
        ):
        max_sentence_length = int(dst_lengths.max())

        # call the encoder
        all_hidden_states, hidden_states, cell_states = self.encoder_module(src_sentences, src_lengths)
        # call the decoder
        pre_logits = self.decoder_module(all_hidden_states, hidden_states, cell_states, max_sentence_length, dst_sentences)

        # note: pre-softmax scores are returned because the cross-entropy loss applies softmax internally
        return pre_logits
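

# Minimal usage sketch, not part of the original paste. The vocabulary size, start-token
# index, batch shapes and hyperparameters below are illustrative assumptions; it only
# shows how the model above could be instantiated and trained with CrossEntropyLoss
# on the returned pre-softmax scores.
if __name__ == "__main__":
    model = NeuralMachineTranslator(
        src_vocab_size=10000,      # assumed vocabulary size
        start_idx=1,               # assumed <s> token index
        embeddings_size=256,
        hidden_size=512,
        dropout_prob=0.3,
        teacher_forcing=True,
    ).to(DEVICE)

    # Dummy batch: index 0 is padding, lengths sorted in decreasing order as the
    # encoder expects for pack_padded_sequence.
    src_sentences = torch.randint(2, 10000, (4, 12), device=DEVICE)   # (batch, max_src_len)
    dst_sentences = torch.randint(2, 10000, (4, 15), device=DEVICE)   # (batch, max_dst_len)
    src_lengths = torch.tensor([12, 11, 9, 7])
    dst_lengths = torch.tensor([15, 14, 10, 8])

    pre_logits = model(src_sentences, dst_sentences, src_lengths, dst_lengths)   # (batch, max_dst_len, vocab)
    loss = nn.CrossEntropyLoss(ignore_index=0)(
        pre_logits.view(-1, pre_logits.size(-1)), dst_sentences.view(-1))
    loss.backward()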